Repository: G-Research/spark-extension
Branch: master
Commit: 65c3dda4a96b
Files: 128
Total size: 857.8 KB
Directory structure:
gitextract_5zjc6xfa/
├── .github/
│ ├── actions/
│ │ ├── build-whl/
│ │ │ └── action.yml
│ │ ├── check-compat/
│ │ │ └── action.yml
│ │ ├── prime-caches/
│ │ │ └── action.yml
│ │ ├── test-jvm/
│ │ │ └── action.yml
│ │ ├── test-python/
│ │ │ └── action.yml
│ │ └── test-release/
│ │ └── action.yml
│ ├── dependabot.yml
│ ├── show-spark-versions.sh
│ └── workflows/
│ ├── build-jvm.yml
│ ├── build-python.yml
│ ├── build-snapshots.yml
│ ├── check.yml
│ ├── ci.yml
│ ├── clear-caches.yaml
│ ├── prepare-release.yml
│ ├── prime-caches.yml
│ ├── publish-release.yml
│ ├── publish-snapshot.yml
│ ├── test-jvm.yml
│ ├── test-python.yml
│ ├── test-release.yml
│ ├── test-results.yml
│ └── test-snapshots.yml
├── .gitignore
├── .scalafmt.conf
├── CHANGELOG.md
├── CONDITIONAL.md
├── DIFF.md
├── GROUPS.md
├── HISTOGRAM.md
├── LICENSE
├── MAINTAINERS.md
├── PARQUET.md
├── PARTITIONING.md
├── PYSPARK-DEPS.md
├── README.md
├── RELEASE.md
├── ROW_NUMBER.md
├── SECURITY.md
├── build-whl.sh
├── bump-version.sh
├── examples/
│ └── python-deps/
│ ├── Dockerfile
│ ├── docker-compose.yml
│ └── example.py
├── pom.xml
├── python/
│ ├── README.md
│ ├── gresearch/
│ │ ├── __init__.py
│ │ └── spark/
│ │ ├── __init__.py
│ │ ├── diff/
│ │ │ ├── __init__.py
│ │ │ └── comparator/
│ │ │ └── __init__.py
│ │ └── parquet/
│ │ └── __init__.py
│ ├── pyproject.toml
│ ├── pyspark/
│ │ └── jars/
│ │ └── .gitignore
│ ├── setup.py
│ └── test/
│ ├── __init__.py
│ ├── spark_common.py
│ ├── test_diff.py
│ ├── test_histogram.py
│ ├── test_job_description.py
│ ├── test_jvm.py
│ ├── test_package.py
│ ├── test_parquet.py
│ └── test_row_number.py
├── release.sh
├── set-version.sh
├── src/
│ ├── main/
│ │ ├── scala/
│ │ │ └── uk/
│ │ │ └── co/
│ │ │ └── gresearch/
│ │ │ ├── package.scala
│ │ │ └── spark/
│ │ │ ├── BuildVersion.scala
│ │ │ ├── Histogram.scala
│ │ │ ├── RowNumbers.scala
│ │ │ ├── SparkVersion.scala
│ │ │ ├── UnpersistHandle.scala
│ │ │ ├── diff/
│ │ │ │ ├── App.scala
│ │ │ │ ├── Diff.scala
│ │ │ │ ├── DiffComparators.scala
│ │ │ │ ├── DiffOptions.scala
│ │ │ │ ├── comparator/
│ │ │ │ │ ├── DefaultDiffComparator.scala
│ │ │ │ │ ├── DiffComparator.scala
│ │ │ │ │ ├── DurationDiffComparator.scala
│ │ │ │ │ ├── EpsilonDiffComparator.scala
│ │ │ │ │ ├── EquivDiffComparator.scala
│ │ │ │ │ ├── MapDiffComparator.scala
│ │ │ │ │ ├── NullSafeEqualDiffComparator.scala
│ │ │ │ │ ├── TypedDiffComparator.scala
│ │ │ │ │ └── WhitespaceDiffComparator.scala
│ │ │ │ └── package.scala
│ │ │ ├── group/
│ │ │ │ └── package.scala
│ │ │ ├── package.scala
│ │ │ └── parquet/
│ │ │ ├── ParquetMetaDataUtil.scala
│ │ │ └── package.scala
│ │ ├── scala-spark-3.2/
│ │ │ └── uk/
│ │ │ └── co/
│ │ │ └── gresearch/
│ │ │ └── spark/
│ │ │ └── parquet/
│ │ │ └── SplitFile.scala
│ │ ├── scala-spark-3.3/
│ │ │ └── uk/
│ │ │ └── co/
│ │ │ └── gresearch/
│ │ │ └── spark/
│ │ │ └── parquet/
│ │ │ └── SplitFile.scala
│ │ ├── scala-spark-3.5/
│ │ │ ├── org/
│ │ │ │ └── apache/
│ │ │ │ └── spark/
│ │ │ │ └── sql/
│ │ │ │ └── extension/
│ │ │ │ └── package.scala
│ │ │ └── uk/
│ │ │ └── co/
│ │ │ └── gresearch/
│ │ │ └── spark/
│ │ │ └── Backticks.scala
│ │ └── scala-spark-4.0/
│ │ ├── org/
│ │ │ └── apache/
│ │ │ └── spark/
│ │ │ └── sql/
│ │ │ └── extension/
│ │ │ └── package.scala
│ │ └── uk/
│ │ └── co/
│ │ └── gresearch/
│ │ └── spark/
│ │ ├── Backticks.scala
│ │ └── parquet/
│ │ └── SplitFile.scala
│ └── test/
│ ├── files/
│ │ ├── encrypted1.parquet
│ │ ├── encrypted2.parquet
│ │ ├── nested.parquet
│ │ └── test.parquet/
│ │ ├── file1.parquet
│ │ └── file2.parquet
│ ├── java/
│ │ └── uk/
│ │ └── co/
│ │ └── gresearch/
│ │ └── test/
│ │ ├── SparkJavaTests.java
│ │ └── diff/
│ │ ├── DiffJavaTests.java
│ │ ├── JavaValue.java
│ │ └── JavaValueAs.java
│ ├── resources/
│ │ ├── log4j.properties
│ │ └── log4j2.properties
│ ├── scala/
│ │ └── uk/
│ │ └── co/
│ │ └── gresearch/
│ │ ├── spark/
│ │ │ ├── GroupBySuite.scala
│ │ │ ├── HistogramSuite.scala
│ │ │ ├── SparkSuite.scala
│ │ │ ├── SparkTestSession.scala
│ │ │ ├── WritePartitionedSuite.scala
│ │ │ ├── diff/
│ │ │ │ ├── AppSuite.scala
│ │ │ │ ├── DiffComparatorSuite.scala
│ │ │ │ ├── DiffOptionsSuite.scala
│ │ │ │ ├── DiffSuite.scala
│ │ │ │ └── examples/
│ │ │ │ └── Examples.scala
│ │ │ ├── group/
│ │ │ │ └── GroupSuite.scala
│ │ │ ├── parquet/
│ │ │ │ └── ParquetSuite.scala
│ │ │ └── test/
│ │ │ └── package.scala
│ │ └── test/
│ │ ├── ClasspathSuite.scala
│ │ ├── Spec.scala
│ │ └── Suite.scala
│ ├── scala-spark-3/
│ │ └── uk/
│ │ └── co/
│ │ └── gresearch/
│ │ └── spark/
│ │ └── SparkSuiteHelper.scala
│ └── scala-spark-4/
│ └── uk/
│ └── co/
│ └── gresearch/
│ └── spark/
│ └── SparkSuiteHelper.scala
├── test-release.py
├── test-release.scala
└── test-release.sh
================================================
FILE CONTENTS
================================================
================================================
FILE: .github/actions/build-whl/action.yml
================================================
name: 'Build Whl'
author: 'EnricoMi'
description: 'A GitHub Action that builds the pyspark-extension package'
inputs:
spark-version:
description: Spark version, e.g. 3.4.0, 3.4.0-SNAPSHOT, or 4.0.0-preview1
required: true
scala-version:
description: Scala version, e.g. 2.12.15
required: true
spark-compat-version:
description: Spark compatibility version, e.g. 3.4
required: true
scala-compat-version:
description: Scala compatibility version, e.g. 2.12
required: true
java-compat-version:
description: Java compatibility version, e.g. 8
required: true
python-version:
description: Python version, e.g. 3.8
required: true
runs:
using: 'composite'
steps:
- name: Fetch Binaries Artifact
uses: actions/download-artifact@v4
with:
name: Binaries-${{ inputs.spark-compat-version }}-${{ inputs.scala-compat-version }}
path: .
- name: Set versions in pom.xml
run: |
./set-version.sh ${{ inputs.spark-version }} ${{ inputs.scala-version }}
git diff
shell: bash
- name: Make this work with PySpark preview versions
if: contains(inputs.spark-version, 'preview')
run: |
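# Rewrite the pyspark requirement in setup.py: appending ".dev1" lets pip
# resolve preview/dev releases, and the generic "{spark_compat_version}.0"
# placeholder is replaced by the exact preview version.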
sed -i -e 's/f"\(pyspark~=.*\)"/f"\1.dev1"/' -e 's/f"\({spark_compat_version}.0\)"/"${{ inputs.spark-version }}"/g' python/setup.py
git diff python/setup.py
shell: bash
- name: Restore Maven packages cache
if: github.event_name != 'schedule'
uses: actions/cache/restore@v4
with:
path: ~/.m2/repository
key: ${{ runner.os }}-mvn-build-${{ inputs.spark-version }}-${{ inputs.scala-version }}-${{ hashFiles('pom.xml') }}
restore-keys: |
${{ runner.os }}-mvn-build-${{ inputs.spark-version }}-${{ inputs.scala-version }}-${{ hashFiles('pom.xml') }}
${{ runner.os }}-mvn-build-${{ inputs.spark-version }}-${{ inputs.scala-version }}-
- name: Setup JDK ${{ inputs.java-compat-version }}
uses: actions/setup-java@v4
with:
java-version: ${{ inputs.java-compat-version }}
distribution: 'zulu'
- name: Fetch Release Test Dependencies
run: |
# Fetch Release Test Dependencies
echo "::group::mvn dependency:get"
mvn dependency:get -Dtransitive=false -Dartifact=org.apache.parquet:parquet-hadoop:1.16.0:jar:tests
echo "::endgroup::"
shell: bash
- name: Setup Python
uses: actions/setup-python@v5
with:
python-version: ${{ inputs.python-version }}
- name: Install Python dependencies
run: |
# Install Python dependencies
echo "::group::mvn compile"
python -m pip install --upgrade pip build twine
echo "::endgroup::"
shell: bash
- name: Build whl
run: |
# Build whl
echo "::group::build-whl.sh"
./build-whl.sh
echo "::endgroup::"
shell: bash
- name: Test whl
run: |
# Test whl
echo "::group::test-release.py"
twine check python/dist/*
# .dev1 allows this to work with preview versions
pip install python/dist/*.whl "pyspark~=${{ inputs.spark-compat-version }}.0.dev1"
python test-release.py
echo "::endgroup::"
shell: bash
- name: Upload whl
uses: actions/upload-artifact@v4
with:
name: Whl (Spark ${{ inputs.spark-compat-version }} Scala ${{ inputs.scala-compat-version }})
path: |
python/dist/*.whl
- name: Build whl with mvn
env:
JDK_JAVA_OPTIONS: --add-exports java.base/sun.nio.ch=ALL-UNNAMED --add-exports java.base/sun.util.calendar=ALL-UNNAMED
run: |
# Build whl with mvn
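# Clean all previous build artifacts first, then rebuild the wheel from
# scratch; presumably build-whl.sh then has to rebuild the jar via mvn
# (hence "with mvn") instead of reusing the downloaded binaries.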
rm -rf target python/dist python/pyspark_extension.egg-info pyspark/jars/*.jar
echo "::group::build-whl.sh"
./build-whl.sh
echo "::endgroup::"
shell: bash
branding:
icon: 'check-circle'
color: 'green'
================================================
FILE: .github/actions/check-compat/action.yml
================================================
name: 'Check'
author: 'EnricoMi'
description: 'A GitHub Action that checks compatibility of spark-extension'
inputs:
spark-version:
description: Spark version, e.g. 3.4.0 or 3.4.0-SNAPSHOT
required: true
scala-version:
description: Scala version, e.g. 2.12.15
required: true
spark-compat-version:
description: Spark compatibility version, e.g. 3.4
required: true
scala-compat-version:
description: Scala compatibility version, e.g. 2.12
required: true
package-version:
description: Spark-Extension version to check against
required: true
runs:
using: 'composite'
steps:
- name: Fetch Binaries Artifact
uses: actions/download-artifact@v4
with:
name: Binaries-${{ inputs.spark-compat-version }}-${{ inputs.scala-compat-version }}
path: .
- name: Set versions in pom.xml
run: |
./set-version.sh ${{ inputs.spark-version }} ${{ inputs.scala-version }}
git diff
shell: bash
- name: Restore Maven packages cache
if: github.event_name != 'schedule'
uses: actions/cache/restore@v4
with:
path: ~/.m2/repository
key: ${{ runner.os }}-mvn-build-${{ inputs.spark-version }}-${{ inputs.scala-version }}-${{ hashFiles('pom.xml') }}
restore-keys: |
${{ runner.os }}-mvn-build-${{ inputs.spark-version }}-${{ inputs.scala-version }}-${{ hashFiles('pom.xml') }}
${{ runner.os }}-mvn-build-${{ inputs.spark-version }}-${{ inputs.scala-version }}-
- name: Setup JDK 1.8
uses: actions/setup-java@v4
with:
java-version: '8'
distribution: 'zulu'
- name: Install Checker
run: |
# Install Checker
echo "::group::apt update install"
sudo apt update
sudo apt install japi-compliance-checker
echo "::endgroup::"
shell: bash
- name: Release exists
id: exists
continue-on-error: true
run: |
# Release exists
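# Probe Maven Central for the released jar. With continue-on-error, a missing
# release does not fail the job; later steps check steps.exists.outcome instead.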
curl --head --fail https://repo1.maven.org/maven2/uk/co/gresearch/spark/spark-extension_${{ inputs.scala-compat-version }}/${{ inputs.package-version }}-${{ inputs.spark-compat-version }}/spark-extension_${{ inputs.scala-compat-version }}-${{ inputs.package-version }}-${{ inputs.spark-compat-version }}.jar
shell: bash
- name: Fetch package
if: steps.exists.outcome == 'success'
run: |
# Fetch package
echo "::group::mvn dependency:get"
mvn dependency:get -Dtransitive=false -DremoteRepositories -Dartifact=uk.co.gresearch.spark:spark-extension_${{ inputs.scala-compat-version }}:${{ inputs.package-version }}-${{ inputs.spark-compat-version }}
echo "::endgroup::"
shell: bash
- name: Check
if: steps.exists.outcome == 'success'
continue-on-error: ${{ github.ref == 'refs/heads/master' }}
run: |
# Check
echo "::group::japi-compliance-checker"
ls -lah ~/.m2/repository/uk/co/gresearch/spark/spark-extension_${{ inputs.scala-compat-version }}/${{ inputs.package-version }}-${{ inputs.spark-compat-version }}/spark-extension_${{ inputs.scala-compat-version }}-${{ inputs.package-version }}-${{ inputs.spark-compat-version }}.jar target/spark-extension*.jar
japi-compliance-checker ~/.m2/repository/uk/co/gresearch/spark/spark-extension_${{ inputs.scala-compat-version }}/${{ inputs.package-version }}-${{ inputs.spark-compat-version }}/spark-extension_${{ inputs.scala-compat-version }}-${{ inputs.package-version }}-${{ inputs.spark-compat-version }}.jar target/spark-extension*.jar
echo "::endgroup::"
shell: bash
- name: Upload Report
uses: actions/upload-artifact@v4
if: always() && steps.exists.outcome == 'success'
with:
name: Compat-Report-${{ inputs.spark-compat-version }}
path: compat_reports/spark-extension/*
branding:
icon: 'check-circle'
color: 'green'
================================================
FILE: .github/actions/prime-caches/action.yml
================================================
name: 'Prime caches'
author: 'EnricoMi'
description: 'A GitHub Action that primes caches'
inputs:
spark-version:
description: Spark version, e.g. 3.4.0 or 3.4.0-SNAPSHOT
required: true
scala-version:
description: Scala version, e.g. 2.12.15
required: true
spark-compat-version:
description: Spark compatibility version, e.g. 3.4
required: true
scala-compat-version:
description: Scala compatibility version, e.g. 2.12
required: true
java-compat-version:
description: Java compatibility version, e.g. 8
required: true
hadoop-version:
description: Hadoop version, e.g. 2.7 or 2
required: true
runs:
using: 'composite'
steps:
- name: Set versions in pom.xml
run: |
./set-version.sh ${{ inputs.spark-version }} ${{ inputs.scala-version }}
git diff
shell: bash
- name: Check Maven packages cache
id: mvn-build-cache
uses: actions/cache/restore@v4
with:
lookup-only: true
path: ~/.m2/repository
key: ${{ runner.os }}-mvn-build-${{ inputs.spark-version }}-${{ inputs.scala-version }}-${{ hashFiles('pom.xml') }}
- name: Check Spark Binaries cache
id: spark-binaries-cache
uses: actions/cache/restore@v4
with:
lookup-only: true
path: ~/spark
key: ${{ runner.os }}-spark-binaries-${{ inputs.spark-version }}-${{ inputs.scala-compat-version }}
- name: Prepare priming caches
id: setup
run: |
# Prepare priming caches
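# SNAPSHOT versions change between runs, so their caches are always re-primed;
# fixed versions are primed only when the cache lookups above missed.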
if [[ "${{ inputs.spark-version }}" == *"-SNAPSHOT" ]] || [[ -z "${{ steps.mvn-build-cache.outputs.cache-hit }}" ]]; then
echo "prime-mvn-cache=true" >> "$GITHUB_ENV"
echo "prime-some-cache=true" >> "$GITHUB_ENV"
fi;
if [[ "${{ inputs.spark-version }}" == *"-SNAPSHOT" ]] || [[ -z "${{ steps.spark-binaries-cache.outputs.cache-hit }}" ]]; then
echo "prime-spark-cache=true" >> "$GITHUB_ENV"
echo "prime-some-cache=true" >> "$GITHUB_ENV"
fi;
shell: bash
- name: Setup JDK ${{ inputs.java-compat-version }}
if: env.prime-some-cache
uses: actions/setup-java@v4
with:
java-version: ${{ inputs.java-compat-version }}
distribution: 'zulu'
- name: Build
if: env.prime-mvn-cache
env:
JDK_JAVA_OPTIONS: --add-exports java.base/sun.nio.ch=ALL-UNNAMED --add-exports java.base/sun.util.calendar=ALL-UNNAMED
run: |
# Build
echo "::group::mvn dependency:go-offline"
mvn --batch-mode dependency:go-offline
echo "::endgroup::"
shell: bash
- name: Save Maven packages cache
if: env.prime-mvn-cache
uses: actions/cache/save@v4
with:
path: ~/.m2/repository
key: ${{ runner.os }}-mvn-build-${{ inputs.spark-version }}-${{ inputs.scala-version }}-${{ hashFiles('pom.xml') }}-${{ github.run_id }}
- name: Setup Spark Binaries
if: env.prime-spark-cache && ! contains(inputs.spark-version, '-SNAPSHOT')
env:
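# Only Spark 3.x with Scala 2.13 needs the "-scala2.13" archive suffix;
# all other combinations use the plain archive name.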
SPARK_PACKAGE: spark-${{ inputs.spark-version }}/spark-${{ inputs.spark-version }}-bin-hadoop${{ inputs.hadoop-version }}${{ startsWith(inputs.spark-version, '3.') && inputs.scala-compat-version == '2.13' && '-scala2.13' || '' }}.tgz
run: |
wget --progress=dot:giga "https://www.apache.org/dyn/closer.lua/spark/${SPARK_PACKAGE}?action=download" -O - | tar -xzC "${{ runner.temp }}"
archive=$(basename "${SPARK_PACKAGE}") bash -c "mv -v "${{ runner.temp }}/\${archive/%.tgz/}" ~/spark"
shell: bash
- name: Save Spark Binaries cache
if: env.prime-spark-cache && ! contains(inputs.spark-version, '-SNAPSHOT')
uses: actions/cache/save@v4
with:
path: ~/spark
key: ${{ runner.os }}-spark-binaries-${{ inputs.spark-version }}-${{ inputs.scala-compat-version }}-${{ github.run_id }}
branding:
icon: 'check-circle'
color: 'green'
================================================
FILE: .github/actions/test-jvm/action.yml
================================================
name: 'Test JVM'
author: 'EnricoMi'
description: 'A GitHub Action that tests JVM spark-extension'
inputs:
spark-version:
description: Spark version, e.g. 3.4.0, 3.4.0-SNAPSHOT or 4.0.0-preview1
required: true
spark-compat-version:
description: Spark compatibility version, e.g. 3.4
required: true
spark-archive-url:
description: The URL to download the Spark binary distribution
required: false
scala-version:
description: Scala version, e.g. 2.12.15
required: true
scala-compat-version:
description: Scala compatibility version, e.g. 2.12
required: true
hadoop-version:
description: Hadoop version, e.g. 2.7 or 2
required: true
java-compat-version:
description: Java compatibility version, e.g. 8
required: true
runs:
using: 'composite'
steps:
- name: Fetch Binaries Artifact
uses: actions/download-artifact@v4
with:
name: Binaries-${{ inputs.spark-compat-version }}-${{ inputs.scala-compat-version }}
path: .
- name: Set versions in pom.xml
run: |
./set-version.sh ${{ inputs.spark-version }} ${{ inputs.scala-version }}
git diff
shell: bash
- name: Restore Spark Binaries cache
if: github.event_name != 'schedule' && ! contains(inputs.spark-version, '-SNAPSHOT')
uses: actions/cache/restore@v4
with:
path: ~/spark
key: ${{ runner.os }}-spark-binaries-${{ inputs.spark-version }}-${{ inputs.scala-compat-version }}
restore-keys: |
${{ runner.os }}-spark-binaries-${{ inputs.spark-version }}-${{ inputs.scala-compat-version }}
- name: Setup Spark Binaries
if: ( ! contains(inputs.spark-version, '-SNAPSHOT') )
env:
SPARK_PACKAGE: spark-${{ inputs.spark-version }}/spark-${{ inputs.spark-version }}-bin-hadoop${{ inputs.hadoop-version }}${{ startsWith(inputs.spark-version, '3.') && inputs.scala-compat-version == '2.13' && '-scala2.13' || '' }}.tgz
run: |
# Setup Spark Binaries
if [[ ! -e ~/spark ]]
then
url="${{ inputs.spark-archive-url }}"
wget --progress=dot:giga "${url:-https://www.apache.org/dyn/closer.lua/spark/${SPARK_PACKAGE}?action=download}" -O - | tar -xzC "${{ runner.temp }}"
archive=$(basename "${SPARK_PACKAGE}") bash -c "mv -v "${{ runner.temp }}/\${archive/%.tgz/}" ~/spark"
fi
echo "SPARK_HOME=$(cd ~/spark; pwd)" >> $GITHUB_ENV
shell: bash
- name: Restore Maven packages cache
if: github.event_name != 'schedule'
uses: actions/cache/restore@v4
with:
path: ~/.m2/repository
key: ${{ runner.os }}-mvn-build-${{ inputs.spark-version }}-${{ inputs.scala-version }}-${{ hashFiles('pom.xml') }}
restore-keys: |
${{ runner.os }}-mvn-build-${{ inputs.spark-version }}-${{ inputs.scala-version }}-${{ hashFiles('pom.xml') }}
${{ runner.os }}-mvn-build-${{ inputs.spark-version }}-${{ inputs.scala-version }}-
- name: Setup JDK ${{ inputs.java-compat-version }}
uses: actions/setup-java@v4
with:
java-version: ${{ inputs.java-compat-version }}
distribution: 'zulu'
- name: Scala and Java Tests
env:
JDK_JAVA_OPTIONS: --add-exports java.base/sun.nio.ch=ALL-UNNAMED --add-exports java.base/sun.util.calendar=ALL-UNNAMED
run: |
# Scala and Java Tests
echo "::group::mvn test"
mvn --batch-mode --update-snapshots -Dspotless.check.skip test integration-test
echo "::endgroup::"
shell: bash
- name: Upload Test Results
if: always()
uses: actions/upload-artifact@v4
with:
name: JVM Test Results (Spark ${{ inputs.spark-version }} Scala ${{ inputs.scala-version }})
path: |
target/surefire-*reports/*.xml
branding:
icon: 'check-circle'
color: 'green'
================================================
FILE: .github/actions/test-python/action.yml
================================================
name: 'Test Python'
author: 'EnricoMi'
description: 'A GitHub Action that tests Python spark-extension'
# PySpark is not available for snapshots or for Scala versions other than 2.12.
# We would have to compile Spark from source for this, which is not worth it,
# so this action only works with Scala 2.12 and non-SNAPSHOT Spark versions.
inputs:
spark-version:
description: Spark version, e.g. 3.4.0 or 4.0.0-preview1
required: true
scala-version:
description: Scala version, e.g. 2.12.15
required: true
spark-compat-version:
description: Spark compatibility version, e.g. 3.4
required: true
spark-archive-url:
description: The URL to download the Spark binary distribution
required: false
spark-package-repo:
description: The URL of an alternate maven repository to fetch Spark packages
required: false
scala-compat-version:
description: Scala compatibility version, e.g. 2.12
required: true
java-compat-version:
description: Java compatibility version, e.g. 8
required: true
hadoop-version:
description: Hadoop version, e.g. 2.7 or 2
required: true
python-version:
description: Python version, e.g. 3.8
required: true
runs:
using: 'composite'
steps:
- name: Fetch Binaries Artifact
uses: actions/download-artifact@v4
with:
name: Binaries-${{ inputs.spark-compat-version }}-${{ inputs.scala-compat-version }}
path: .
- name: Set versions in pom.xml
run: |
./set-version.sh ${{ inputs.spark-version }} ${{ inputs.scala-version }}
git diff
SPARK_EXTENSION_VERSION=$(grep --max-count=1 "<version>.*</version>" pom.xml | sed -E -e "s/\s*<[^>]+>//g")
echo "SPARK_EXTENSION_VERSION=$SPARK_EXTENSION_VERSION" | tee -a "$GITHUB_ENV"
shell: bash
- name: Make this work with PySpark preview versions
if: contains(inputs.spark-version, 'preview')
run: |
sed -i -e 's/\({spark_compat_version}.0\)"/\1.dev1"/' python/setup.py
git diff python/setup.py
shell: bash
- name: Restore Spark Binaries cache
if: github.event_name != 'schedule' && ( startsWith(inputs.spark-version, '3.') && inputs.scala-compat-version == '2.12' || startsWith(inputs.spark-version, '4.') ) && ! contains(inputs.spark-version, '-SNAPSHOT')
uses: actions/cache/restore@v4
with:
path: ~/spark
key: ${{ runner.os }}-spark-binaries-${{ inputs.spark-version }}-${{ inputs.scala-compat-version }}
restore-keys: |
${{ runner.os }}-spark-binaries-${{ inputs.spark-version }}-${{ inputs.scala-compat-version }}
- name: Setup Spark Binaries
if: ( startsWith(inputs.spark-version, '3.') && inputs.scala-compat-version == '2.12' || startsWith(inputs.spark-version, '4.') ) && ! contains(inputs.spark-version, '-SNAPSHOT')
env:
SPARK_PACKAGE: spark-${{ inputs.spark-version }}/spark-${{ inputs.spark-version }}-bin-hadoop${{ inputs.hadoop-version }}${{ startsWith(inputs.spark-version, '3.') && inputs.scala-compat-version == '2.13' && '-scala2.13' || '' }}.tgz
run: |
# Setup Spark Binaries
if [[ ! -e ~/spark ]]
then
url="${{ inputs.spark-archive-url }}"
wget --progress=dot:giga "${url:-https://www.apache.org/dyn/closer.lua/spark/${SPARK_PACKAGE}?action=download}" -O - | tar -xzC "${{ runner.temp }}"
archive=$(basename "${SPARK_PACKAGE}") bash -c "mv -v "${{ runner.temp }}/\${archive/%.tgz/}" ~/spark"
fi
echo "SPARK_BIN_HOME=$(cd ~/spark; pwd)" >> $GITHUB_ENV
shell: bash
- name: Restore Maven packages cache
if: github.event_name != 'schedule'
uses: actions/cache/restore@v4
with:
path: ~/.m2/repository
key: ${{ runner.os }}-mvn-build-${{ inputs.spark-version }}-${{ inputs.scala-version }}-${{ hashFiles('pom.xml') }}
restore-keys: |
${{ runner.os }}-mvn-build-${{ inputs.spark-version }}-${{ inputs.scala-version }}-${{ hashFiles('pom.xml') }}
${{ runner.os }}-mvn-build-${{ inputs.spark-version }}-${{ inputs.scala-version }}-
- name: Setup JDK ${{ inputs.java-compat-version }}
uses: actions/setup-java@v4
with:
java-version: ${{ inputs.java-compat-version }}
distribution: 'zulu'
- name: Setup Python
uses: actions/setup-python@v5
with:
python-version: ${{ inputs.python-version }}
- name: Install Python dependencies
run: |
# Install Python dependencies
echo "::group::pip install"
python -m venv .pytest-venv
.pytest-venv/bin/python -m pip install --upgrade pip
.pytest-venv/bin/pip install pypandoc
.pytest-venv/bin/pip install -e python/[test]
echo "::endgroup::"
PYSPARK_HOME=$(.pytest-venv/bin/python -c "import os; import pyspark; print(os.path.dirname(pyspark.__file__))")
PYSPARK_BIN_HOME="$(cd ".pytest-venv/"; pwd)"
PYSPARK_PYTHON="$PYSPARK_BIN_HOME/bin/python"
echo "PYSPARK_HOME=$PYSPARK_HOME" | tee -a "$GITHUB_ENV"
echo "PYSPARK_BIN_HOME=$PYSPARK_BIN_HOME" | tee -a "$GITHUB_ENV"
echo "PYSPARK_PYTHON=$PYSPARK_PYTHON" | tee -a "$GITHUB_ENV"
shell: bash
- name: Prepare Poetry tests
run: |
# Prepare Poetry tests
echo "::group::Prepare poetry tests"
# install poetry in venv
python -m venv .poetry-venv
.poetry-venv/bin/python -m pip install poetry
# env var needed by poetry tests
echo "POETRY_PYTHON=$PWD/.poetry-venv/bin/python" | tee -a "$GITHUB_ENV"
# clone example poetry project
git clone https://github.com/Textualize/rich.git .rich
cd .rich
git reset --hard 20024635c06c22879fd2fd1e380ec4cccd9935dd
# env var needed by poetry tests
echo "RICH_SOURCES=$PWD" | tee -a "$GITHUB_ENV"
echo "::endgroup::"
shell: bash
- name: Python Unit Tests
env:
SPARK_HOME: ${{ env.PYSPARK_HOME }}
PYTHONPATH: python/test
run: |
.pytest-venv/bin/python -m pytest python/test --junit-xml test-results/pytest-$(date +%s.%N)-$RANDOM.xml
shell: bash
- name: Install Spark Extension
run: |
# Install Spark Extension
echo "::group::mvn install"
mvn --batch-mode --update-snapshots install -Dspotless.check.skip -DskipTests -Dmaven.test.skip=true -Dgpg.skip
echo "::endgroup::"
shell: bash
- name: Start Spark Connect
id: spark-connect
if: ( contains('3.4,3.5', inputs.spark-compat-version) && inputs.scala-compat-version == '2.12' || startsWith(inputs.spark-version, '4.') ) && ! contains(inputs.spark-version, '-SNAPSHOT')
env:
SPARK_HOME: ${{ env.SPARK_BIN_HOME }}
CONNECT_GRPC_BINDING_ADDRESS: 127.0.0.1
CONNECT_GRPC_BINDING_PORT: 15002
run: |
# Start Spark Connect
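# Try up to 10 times to start the server, verifying that port 15002 is
# listening. Server logs are wrapped in ::stop-commands:: so their content
# cannot be interpreted as GitHub workflow commands.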
for attempt in {1..10}; do
$SPARK_HOME/sbin/start-connect-server.sh --packages org.apache.spark:spark-connect_${{ inputs.scala-compat-version }}:${{ inputs.spark-version }} --repositories "${{ inputs.spark-package-repo }}"
sleep 10
for log in $SPARK_HOME/logs/spark-*-org.apache.spark.sql.connect.service.SparkConnectServer-*.out; do
echo "::group::Spark Connect server log: $log"
eoc="EOC-$RANDOM"
echo "::stop-commands::$eoc"
cat "$log" || true
echo "::$eoc::"
echo "::endgroup::"
done
if netstat -an | grep 15002; then
break;
fi
echo "::warning title=Starting Spark Connect server failed::Attempt #$attempt to start Spark Connect server failed"
$SPARK_HOME/sbin/stop-connect-server.sh --packages org.apache.spark:spark-connect_${{ inputs.scala-compat-version }}:${{ inputs.spark-version }}
sleep 5
done
if ! netstat -an | grep 15002; then
echo "::error title=Starting Spark Connect server failed::All attempts to start Spark Connect server failed"
exit 1
fi
shell: bash
- name: Python Unit Tests (Spark Connect)
if: steps.spark-connect.outcome == 'success'
env:
SPARK_HOME: ${{ env.PYSPARK_HOME }}
PYTHONPATH: python/test
TEST_SPARK_CONNECT_SERVER: sc://127.0.0.1:15002
run: |
# Python Unit Tests (Spark Connect)
echo "::group::pip install"
# .dev1 allows this to work with preview versions
.pytest-venv/bin/pip install "pyspark[connect]~=${{ inputs.spark-compat-version }}.0.dev1"
echo "::endgroup::"
.pytest-venv/bin/python -m pytest python/test --junit-xml test-results-connect/pytest-$(date +%s.%N)-$RANDOM.xml
shell: bash
- name: Stop Spark Connect
if: always() && steps.spark-connect.outcome == 'success'
env:
SPARK_HOME: ${{ env.SPARK_BIN_HOME }}
run: |
# Stop Spark Connect
$SPARK_HOME/sbin/stop-connect-server.sh
for log in $SPARK_HOME/logs/spark-*-org.apache.spark.sql.connect.service.SparkConnectServer-*.out; do
echo "::group::Spark Connect server log: $log"
eoc="EOC-$RANDOM"
echo "::stop-commands::$eoc"
cat "$log" || true
echo "::$eoc::"
echo "::endgroup::"
done
shell: bash
- name: Upload Test Results
if: always()
uses: actions/upload-artifact@v4
with:
name: Python Test Results (Spark ${{ inputs.spark-version }} Scala ${{ inputs.scala-version }} Python ${{ inputs.python-version }})
path: |
test-results/*.xml
test-results-connect/*.xml
branding:
icon: 'check-circle'
color: 'green'
================================================
FILE: .github/actions/test-release/action.yml
================================================
name: 'Test Release'
author: 'EnricoMi'
description: 'A GitHub Action that tests spark-extension release'
# PySpark is not available for snapshots or for Scala versions other than 2.12.
# We would have to compile Spark from source for this, which is not worth it,
# so this action only works with Scala 2.12 and non-SNAPSHOT Spark versions.
inputs:
spark-version:
description: Spark version, e.g. 3.4.0 or 4.0.0-preview1
required: true
scala-version:
description: Scala version, e.g. 2.12.15
required: true
spark-compat-version:
description: Spark compatibility version, e.g. 3.4
required: true
spark-archive-url:
description: The URL to download the Spark binary distribution
required: false
scala-compat-version:
description: Scala compatibility version, e.g. 2.12
required: true
java-compat-version:
description: Java compatibility version, e.g. 8
required: true
hadoop-version:
description: Hadoop version, e.g. 2.7 or 2
required: true
python-version:
description: Python version, e.g. 3.8
default: ''
required: false
runs:
using: 'composite'
steps:
- name: Fetch Binaries Artifact
uses: actions/download-artifact@v4
with:
name: Binaries-${{ inputs.spark-compat-version }}-${{ inputs.scala-compat-version }}
path: .
- name: Set versions in pom.xml
run: |
./set-version.sh ${{ inputs.spark-version }} ${{ inputs.scala-version }}
git diff
SPARK_EXTENSION_VERSION=$(grep --max-count=1 "<version>.*</version>" pom.xml | sed -E -e "s/\s*<[^>]+>//g")
echo "SPARK_EXTENSION_VERSION=$SPARK_EXTENSION_VERSION" | tee -a "$GITHUB_ENV"
shell: bash
- name: Restore Spark Binaries cache
if: github.event_name != 'schedule'
uses: actions/cache/restore@v4
with:
path: ~/spark
key: ${{ runner.os }}-spark-binaries-${{ inputs.spark-version }}-${{ inputs.scala-compat-version }}
restore-keys: |
${{ runner.os }}-spark-binaries-${{ inputs.spark-version }}-${{ inputs.scala-compat-version }}
- name: Setup Spark Binaries
env:
SPARK_PACKAGE: spark-${{ inputs.spark-version }}/spark-${{ inputs.spark-version }}-bin-hadoop${{ inputs.hadoop-version }}${{ startsWith(inputs.spark-version, '3.') && inputs.scala-compat-version == '2.13' && '-scala2.13' || '' }}.tgz
run: |
# Setup Spark Binaries
if [[ ! -e ~/spark ]]
then
url="${{ inputs.spark-archive-url }}"
wget --progress=dot:giga "${url:-https://www.apache.org/dyn/closer.lua/spark/${SPARK_PACKAGE}?action=download}" -O - | tar -xzC "${{ runner.temp }}"
archive=$(basename "${SPARK_PACKAGE}") bash -c "mv -v "${{ runner.temp }}/\${archive/%.tgz/}" ~/spark"
fi
echo "SPARK_BIN_HOME=$(cd ~/spark; pwd)" >> $GITHUB_ENV
shell: bash
- name: Restore Maven packages cache
if: github.event_name != 'schedule'
uses: actions/cache/restore@v4
with:
path: ~/.m2/repository
key: ${{ runner.os }}-mvn-build-${{ inputs.spark-version }}-${{ inputs.scala-version }}-${{ hashFiles('pom.xml') }}
restore-keys: |
${{ runner.os }}-mvn-build-${{ inputs.spark-version }}-${{ inputs.scala-version }}-${{ hashFiles('pom.xml') }}
${{ runner.os }}-mvn-build-${{ inputs.spark-version }}-${{ inputs.scala-version }}-
- name: Setup JDK ${{ inputs.java-compat-version }}
uses: actions/setup-java@v4
with:
java-version: ${{ inputs.java-compat-version }}
distribution: 'zulu'
- name: Diff App test
env:
SPARK_HOME: ${{ env.SPARK_BIN_HOME }}
run: |
# Diff App test
echo "::group::spark-submit"
$SPARK_HOME/bin/spark-submit --packages com.github.scopt:scopt_${{ inputs.scala-compat-version }}:4.1.0 target/spark-extension_*.jar --format parquet --id id src/test/files/test.parquet/file1.parquet src/test/files/test.parquet/file2.parquet diff.parquet
echo
echo "::endgroup::"
echo "::group::spark-shell"
$SPARK_HOME/bin/spark-shell <<< 'val df = spark.read.parquet("diff.parquet").orderBy($"id").groupBy($"diff").count; df.show; if (df.count != 2) sys.exit(1)'
echo
echo "::endgroup::"
shell: bash
- name: Install Spark Extension
run: |
# Install Spark Extension
echo "::group::mvn install"
mvn --batch-mode --update-snapshots install -Dspotless.check.skip -DskipTests -Dmaven.test.skip=true -Dgpg.skip
echo "::endgroup::"
shell: bash
- name: Fetch Release Test Dependencies
run: |
# Fetch Release Test Dependencies
echo "::group::mvn dependency:get"
mvn dependency:get -Dtransitive=false -Dartifact=org.apache.parquet:parquet-hadoop:1.16.0:jar:tests
echo "::endgroup::"
shell: bash
- name: Scala Release Test
env:
SPARK_HOME: ${{ env.SPARK_BIN_HOME }}
run: |
# Scala Release Test
echo "::group::spark-shell"
$SPARK_BIN_HOME/bin/spark-shell --packages uk.co.gresearch.spark:spark-extension_${{ inputs.scala-compat-version }}:$SPARK_EXTENSION_VERSION --jars ~/.m2/repository/org/apache/parquet/parquet-hadoop/1.16.0/parquet-hadoop-1.16.0-tests.jar < test-release.scala
echo
echo "::endgroup::"
shell: bash
- name: Setup Python
uses: actions/setup-python@v5
if: inputs.python-version != ''
with:
python-version: ${{ inputs.python-version }}
- name: Python Release Test
if: inputs.python-version != ''
env:
SPARK_HOME: ${{ env.SPARK_BIN_HOME }}
run: |
# Python Release Test
echo "::group::spark-submit"
$SPARK_BIN_HOME/bin/spark-submit --packages uk.co.gresearch.spark:spark-extension_${{ inputs.scala-compat-version }}:$SPARK_EXTENSION_VERSION test-release.py
echo
echo "::endgroup::"
shell: bash
- name: Fetch Whl Artifact
if: inputs.python-version != ''
uses: actions/download-artifact@v4
with:
name: Whl (Spark ${{ inputs.spark-compat-version }} Scala ${{ inputs.scala-compat-version }})
path: .
- name: Install Python dependencies
if: inputs.python-version != ''
run: |
# Install Python dependencies
echo "::group::pip install"
python -m venv .pytest-venv
.pytest-venv/bin/python -m pip install --upgrade pip
.pytest-venv/bin/pip install pypandoc
.pytest-venv/bin/pip install $(ls pyspark_extension-*.whl)[test]
echo "::endgroup::"
PYSPARK_HOME=$(.pytest-venv/bin/python -c "import os; import pyspark; print(os.path.dirname(pyspark.__file__))")
PYSPARK_BIN_HOME="$(cd ".pytest-venv/"; pwd)"
PYSPARK_PYTHON="$PYSPARK_BIN_HOME/bin/python"
echo "PYSPARK_HOME=$PYSPARK_HOME" | tee -a "$GITHUB_ENV"
echo "PYSPARK_BIN_HOME=$PYSPARK_BIN_HOME" | tee -a "$GITHUB_ENV"
echo "PYSPARK_PYTHON=$PYSPARK_PYTHON" | tee -a "$GITHUB_ENV"
shell: bash
- name: PySpark Release Test
if: inputs.python-version != ''
run: |
.pytest-venv/bin/python3 test-release.py
shell: bash
- name: Python Integration Tests
if: inputs.python-version != ''
env:
SPARK_HOME: ${{ env.PYSPARK_HOME }}
PYTHONPATH: python:python/test
run: |
# Python Integration Tests
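# Run every test file via spark-submit even if one fails; remember any
# failure in $state and exit non-zero only after all tests have run.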
source .pytest-venv/bin/activate
find python/test -name 'test*.py' > tests
while read test
do
echo "::group::spark-submit $test"
if ! $PYSPARK_BIN_HOME/bin/spark-submit --master "local[2]" --packages uk.co.gresearch.spark:spark-extension_${{ inputs.scala-compat-version }}:$SPARK_EXTENSION_VERSION "$test" test-results-submit
then
state="fail"
fi
echo
echo "::endgroup::"
done < tests
if [[ "$state" == "fail" ]]; then exit 1; fi
shell: bash
- name: Upload Test Results
if: always() && inputs.python-version != ''
uses: actions/upload-artifact@v4
with:
name: Python Release Test Results (Spark ${{ inputs.spark-version }} Scala ${{ inputs.scala-version }} Python ${{ inputs.python-version }})
path: |
test-results-submit/*.xml
branding:
icon: 'check-circle'
color: 'green'
================================================
FILE: .github/dependabot.yml
================================================
version: 2
updates:
- package-ecosystem: "github-actions"
directory: "/"
schedule:
interval: "monthly"
- package-ecosystem: "maven"
directory: "/"
schedule:
interval: "daily"
================================================
FILE: .github/show-spark-versions.sh
================================================
#!/bin/bash
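# Lists the spark-version values used across the workflow files: reconstructs
# full versions (compat plus patch, with -SNAPSHOT where flagged) from the
# prime-caches.yml matrix, then merges them with the explicit spark-version
# entries found in all workflows, sorted and de-duplicated.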
base=$(cd "$(dirname "$0")"; pwd)
grep -- "-version" "$base"/workflows/prime-caches.yml | sed -e "s/ -//g" -e "s/ //g" -e "s/'//g" | grep -v -e "matrix" -e "]" | while read line
do
IFS=":" read var compat_version <<< "$line"
if [[ "$var" == "spark-compat-version" ]]
then
while read line
do
IFS=":" read var patch_version <<< "$line"
if [[ "$var" == "spark-patch-version" ]]
then
echo -n "spark-version: $compat_version.$patch_version"
read line
if [[ "$line" == "spark-snapshot-version:true" ]]
then
echo "-SNAPSHOT"
else
echo
fi
break
fi
done
fi
done > "$base"/workflows/prime-caches.yml.tmp
grep spark-version "$base"/workflows/*.yml "$base"/workflows/prime-caches.yml.tmp | cut -d : -f 2- | sed -e "s/^[ -]*//" -e "s/'//g" -e 's/{"params": {"//g' -e 's/params: {//g' -e 's/"//g' -e "s/,.*//" | grep "^spark-version" | grep -v "matrix" | sort | uniq
================================================
FILE: .github/workflows/build-jvm.yml
================================================
name: Build JVM
on:
workflow_call:
jobs:
build:
name: Build (Spark ${{ matrix.spark-version }} Scala ${{ matrix.scala-version }})
runs-on: ubuntu-latest
strategy:
fail-fast: false
matrix:
include:
- spark-version: '3.2.4'
spark-compat-version: '3.2'
scala-compat-version: '2.12'
scala-version: '2.12.15'
java-compat-version: '8'
hadoop-version: '2.7'
- spark-version: '3.3.4'
spark-compat-version: '3.3'
scala-compat-version: '2.12'
scala-version: '2.12.15'
java-compat-version: '8'
hadoop-version: '3'
- spark-version: '3.4.4'
spark-compat-version: '3.4'
scala-compat-version: '2.12'
scala-version: '2.12.17'
java-compat-version: '8'
hadoop-version: '3'
- spark-version: '3.5.8'
spark-compat-version: '3.5'
scala-compat-version: '2.12'
scala-version: '2.12.18'
java-compat-version: '8'
hadoop-version: '3'
- spark-version: '3.2.4'
spark-compat-version: '3.2'
scala-compat-version: '2.13'
scala-version: '2.13.5'
java-compat-version: '8'
hadoop-version: '3.2'
- spark-version: '3.3.4'
spark-compat-version: '3.3'
scala-compat-version: '2.13'
scala-version: '2.13.8'
java-compat-version: '8'
hadoop-version: '3'
- spark-version: '3.4.4'
spark-compat-version: '3.4'
scala-compat-version: '2.13'
scala-version: '2.13.8'
java-compat-version: '8'
hadoop-version: '3'
- spark-version: '3.5.8'
spark-compat-version: '3.5'
scala-compat-version: '2.13'
scala-version: '2.13.8'
java-compat-version: '8'
hadoop-version: '3'
- spark-version: '4.0.2'
spark-compat-version: '4.0'
scala-compat-version: '2.13'
scala-version: '2.13.16'
java-compat-version: '17'
hadoop-version: '3'
- spark-version: '4.1.1'
spark-compat-version: '4.1'
scala-compat-version: '2.13'
scala-version: '2.13.17'
java-compat-version: '17'
hadoop-version: '3'
- spark-version: '4.2.0-preview3'
spark-compat-version: '4.2'
scala-compat-version: '2.13'
scala-version: '2.13.18'
java-compat-version: '17'
hadoop-version: '3'
steps:
- name: Checkout
uses: actions/checkout@v4
- name: Build
uses: ./.github/actions/build
with:
spark-version: ${{ matrix.spark-version }}
scala-version: ${{ matrix.scala-version }}
spark-compat-version: ${{ matrix.spark-compat-version }}
scala-compat-version: ${{ matrix.scala-compat-version }}
java-compat-version: ${{ matrix.java-compat-version }}
hadoop-version: ${{ matrix.hadoop-version }}
================================================
FILE: .github/workflows/build-python.yml
================================================
name: Build Python
on:
workflow_call:
jobs:
# pyspark<4 is not available for snapshots or for Scala versions other than 2.12
whl:
name: Build whl (Spark ${{ matrix.spark-version }} Scala ${{ matrix.scala-version }})
runs-on: ubuntu-latest
strategy:
fail-fast: false
matrix:
include:
- spark-compat-version: '3.2'
spark-version: '3.2.4'
scala-compat-version: '2.12'
scala-version: '2.12.15'
java-compat-version: '8'
python-version: '3.9'
- spark-compat-version: '3.3'
spark-version: '3.3.4'
scala-compat-version: '2.12'
scala-version: '2.12.15'
java-compat-version: '8'
python-version: '3.9'
- spark-compat-version: '3.4'
spark-version: '3.4.4'
scala-compat-version: '2.12'
scala-version: '2.12.17'
java-compat-version: '8'
python-version: '3.9'
- spark-compat-version: '3.5'
spark-version: '3.5.8'
scala-compat-version: '2.12'
scala-version: '2.12.18'
java-compat-version: '8'
python-version: '3.9'
- spark-compat-version: '4.0'
spark-version: '4.0.2'
scala-compat-version: '2.13'
scala-version: '2.13.16'
java-compat-version: '17'
python-version: '3.9'
- spark-version: '4.1.1'
spark-compat-version: '4.1'
scala-compat-version: '2.13'
scala-version: '2.13.17'
java-compat-version: '17'
hadoop-version: '3'
python-version: '3.10'
- spark-version: '4.2.0-preview3'
spark-compat-version: '4.2'
scala-compat-version: '2.13'
scala-version: '2.13.18'
java-compat-version: '17'
hadoop-version: '3'
python-version: '3.10'
steps:
- name: Checkout
uses: actions/checkout@v4
- name: Build
uses: ./.github/actions/build-whl
with:
spark-version: ${{ matrix.spark-version }}
scala-version: ${{ matrix.scala-version }}
spark-compat-version: ${{ matrix.spark-compat-version }}
scala-compat-version: ${{ matrix.scala-compat-version }}
java-compat-version: ${{ matrix.java-compat-version }}
python-version: ${{ matrix.python-version }}
================================================
FILE: .github/workflows/build-snapshots.yml
================================================
name: Build Snapshots
on:
workflow_call:
jobs:
build:
name: Build (Spark ${{ matrix.spark-version }} Scala ${{ matrix.scala-version }})
runs-on: ubuntu-latest
strategy:
fail-fast: false
matrix:
include:
- spark-compat-version: '3.2'
spark-version: '3.2.5-SNAPSHOT'
scala-compat-version: '2.12'
scala-version: '2.12.15'
java-compat-version: '8'
- spark-compat-version: '3.3'
spark-version: '3.3.5-SNAPSHOT'
scala-compat-version: '2.12'
scala-version: '2.12.15'
java-compat-version: '8'
- spark-compat-version: '3.4'
spark-version: '3.4.5-SNAPSHOT'
scala-compat-version: '2.12'
scala-version: '2.12.17'
java-compat-version: '8'
- spark-compat-version: '3.5'
spark-version: '3.5.9-SNAPSHOT'
scala-compat-version: '2.12'
scala-version: '2.12.18'
java-compat-version: '8'
- spark-compat-version: '3.2'
spark-version: '3.2.5-SNAPSHOT'
scala-compat-version: '2.13'
scala-version: '2.13.5'
java-compat-version: '8'
- spark-compat-version: '3.3'
spark-version: '3.3.5-SNAPSHOT'
scala-compat-version: '2.13'
scala-version: '2.13.8'
java-compat-version: '8'
- spark-compat-version: '3.4'
spark-version: '3.4.5-SNAPSHOT'
scala-compat-version: '2.13'
scala-version: '2.13.8'
java-compat-version: '8'
- spark-compat-version: '3.5'
spark-version: '3.5.9-SNAPSHOT'
scala-compat-version: '2.13'
scala-version: '2.13.8'
java-compat-version: '8'
- spark-compat-version: '4.0'
spark-version: '4.0.3-SNAPSHOT'
scala-compat-version: '2.13'
scala-version: '2.13.16'
java-compat-version: '17'
- spark-compat-version: '4.1'
spark-version: '4.1.2-SNAPSHOT'
scala-compat-version: '2.13'
scala-version: '2.13.17'
java-compat-version: '17'
- spark-compat-version: '4.2'
spark-version: '4.2.0-SNAPSHOT'
scala-compat-version: '2.13'
scala-version: '2.13.18'
java-compat-version: '17'
steps:
- name: Checkout
uses: actions/checkout@v4
- name: Build
uses: ./.github/actions/build
with:
spark-version: ${{ matrix.spark-version }}
scala-version: ${{ matrix.scala-version }}
spark-compat-version: ${{ matrix.spark-compat-version }}-SNAPSHOT
scala-compat-version: ${{ matrix.scala-compat-version }}
java-compat-version: ${{ matrix.java-compat-version }}
================================================
FILE: .github/workflows/check.yml
================================================
name: Check
on:
workflow_call:
jobs:
lint:
name: Scala lint
runs-on: ubuntu-latest
steps:
- name: Checkout
uses: actions/checkout@v4
with:
fetch-depth: 0
- name: Setup JDK 11
uses: actions/setup-java@v4
with:
java-version: '11'
distribution: 'zulu'
- name: Check
id: check
run: |
mvn --batch-mode --update-snapshots spotless:check
shell: bash
- name: Changes
if: failure() && steps.check.outcome == 'failure'
run: |
mvn --batch-mode --update-snapshots spotless:apply
git diff
shell: bash
config:
name: Configure compat
runs-on: ubuntu-latest
outputs:
major-version: ${{ steps.versions.outputs.major-version }}
release-version: ${{ steps.versions.outputs.release-version }}
release-major-version: ${{ steps.versions.outputs.release-major-version }}
steps:
- name: Checkout
uses: actions/checkout@v4
with:
fetch-depth: 0
- name: Get versions
id: versions
run: |
version=$(grep -m1 version pom.xml | sed -e "s/<[^>]*>//g" -e "s/ //g")
echo "version: $version"
echo "major-version: ${version/.*/}"
echo "version=$version" >> "$GITHUB_OUTPUT"
echo "major-version=${version/.*/}" >> "$GITHUB_OUTPUT"
release_version=$(git tag | grep "^v" | sort --version-sort | tail -n1 | sed "s/^v//")
echo "release-version: $release_version"
echo "release-major-version: ${release_version/.*/}"
echo "release-version=$release_version" >> "$GITHUB_OUTPUT"
echo "release-major-version=${release_version/.*/}" >> "$GITHUB_OUTPUT"
shell: bash
compat:
name: Compat (Spark ${{ matrix.spark-compat-version }} Scala ${{ matrix.scala-compat-version }})
needs: config
runs-on: ubuntu-latest
if: needs.config.outputs.major-version == needs.config.outputs.release-major-version
strategy:
fail-fast: false
matrix:
include:
- spark-compat-version: '3.2'
spark-version: '3.2.4'
scala-compat-version: '2.12'
scala-version: '2.12.15'
- spark-compat-version: '3.3'
spark-version: '3.3.4'
scala-compat-version: '2.12'
scala-version: '2.12.15'
- spark-compat-version: '3.4'
scala-compat-version: '2.12'
scala-version: '2.12.17'
spark-version: '3.4.4'
- spark-compat-version: '3.5'
scala-compat-version: '2.12'
scala-version: '2.12.18'
spark-version: '3.5.8'
- spark-compat-version: '4.0'
scala-compat-version: '2.13'
scala-version: '2.13.16'
spark-version: '4.0.2'
- spark-compat-version: '4.1'
scala-compat-version: '2.13'
scala-version: '2.13.17'
spark-version: '4.1.1'
steps:
- name: Checkout
uses: actions/checkout@v4
- name: Check
uses: ./.github/actions/check-compat
with:
spark-version: ${{ matrix.spark-version }}
scala-version: ${{ matrix.scala-version }}
spark-compat-version: ${{ matrix.spark-compat-version }}
scala-compat-version: ${{ matrix.scala-compat-version }}
package-version: ${{ needs.config.outputs.release-version }}
================================================
FILE: .github/workflows/ci.yml
================================================
name: CI
on:
schedule:
- cron: '0 8 */10 * *'
push:
branches:
- 'master'
tags:
- '*'
merge_group:
pull_request:
workflow_dispatch:
jobs:
event_file:
name: "Event File"
runs-on: ubuntu-latest
steps:
- name: Upload
uses: actions/upload-artifact@v4
with:
name: Event File
path: ${{ github.event_path }}
build-jvm:
name: "Build JVM"
uses: "./.github/workflows/build-jvm.yml"
build-snapshots:
name: "Build Snapshots"
uses: "./.github/workflows/build-snapshots.yml"
build-python:
name: "Build Python"
needs: build-jvm
uses: "./.github/workflows/build-python.yml"
test-jvm:
name: "Test JVM"
needs: build-jvm
uses: "./.github/workflows/test-jvm.yml"
test-python:
name: "Test Python"
needs: build-jvm
uses: "./.github/workflows/test-python.yml"
test-snapshots-jvm:
name: "Test Snapshots"
needs: build-snapshots
uses: "./.github/workflows/test-snapshots.yml"
test-release:
name: "Test Release"
needs: build-jvm
uses: "./.github/workflows/test-release.yml"
check:
name: "Check"
needs: build-jvm
uses: "./.github/workflows/check.yml"
# A single job that succeeds if all jobs listed under 'needs' succeed.
# This allows to configure a single job as a required check.
# The 'needed' jobs then can be changed through pull-requests.
test_success:
name: "Test success"
if: always()
runs-on: ubuntu-latest
# the if clauses below have to reflect the number of jobs listed here
needs: [build-jvm, build-python, test-jvm, test-python, test-release]
env:
RESULTS: ${{ join(needs.*.result, ',') }}
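# RESULTS joins the result of each needed job in the order listed above,
# e.g. "success,success,success,skipped,failure"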
steps:
- name: "Success"
# we expect all required jobs to have a success result
if: env.RESULTS == 'success,success,success,success,success'
run: true
shell: bash
- name: "Failure"
# we expect all required jobs to have a success result, fail otherwise
if: env.RESULTS != 'success,success,success,success,success'
run: false
shell: bash
================================================
FILE: .github/workflows/clear-caches.yaml
================================================
name: Clear caches
on:
workflow_dispatch:
permissions:
actions: write
jobs:
clear-cache:
runs-on: ubuntu-latest
steps:
- name: Clear caches
uses: actions/github-script@v7
with:
script: |
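// List every Actions cache in the repository (paginated), then delete each
// one by id.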
const caches = await github.paginate(
github.rest.actions.getActionsCacheList.endpoint.merge({
owner: context.repo.owner,
repo: context.repo.repo,
})
)
for (const cache of caches) {
console.log(cache)
github.rest.actions.deleteActionsCacheById({
owner: context.repo.owner,
repo: context.repo.repo,
cache_id: cache.id,
})
}
================================================
FILE: .github/workflows/prepare-release.yml
================================================
name: Prepare release
on:
workflow_dispatch:
inputs:
github_release_latest:
description: 'Make the created GitHub release the latest'
required: false
default: true
type: boolean
jobs:
get-version:
name: Get version
runs-on: ubuntu-latest
outputs:
release-tag: ${{ steps.versions.outputs.release-tag }}
is-snapshot: ${{ steps.versions.outputs.is-snapshot }}
steps:
- name: Checkout code
uses: actions/checkout@v4
with:
fetch-depth: 0
- name: Get versions
id: versions
run: |
# get release version
version=$(grep --max-count=1 "<version>.*</version>" pom.xml | sed -E -e "s/\s*<[^>]+>//g" -e "s/-SNAPSHOT//" -e "s/-[0-9.]+//g")
is_snapshot=$(if grep -q "<version>.*-SNAPSHOT</version>" pom.xml; then echo "true"; else echo "false"; fi)
# share versions
echo "release-tag=v${version}" >> "$GITHUB_OUTPUT"
echo "is-snapshot=$is_snapshot" >> "$GITHUB_OUTPUT"
prepare-release:
name: Prepare release
runs-on: ubuntu-latest
if: ( ! github.event.repository.fork )
needs: get-version
# secrets are provided by environment
environment:
name: tagged
url: 'https://github.com/G-Research/spark-extension?version=${{ needs.get-version.outputs.release-tag }}'
steps:
- name: Create GitHub App token
uses: actions/create-github-app-token@v2
id: app-token
with:
app-id: ${{ vars.APP_ID }}
private-key: ${{ secrets.PRIVATE_KEY }}
# required to push to a branch
permission-contents: write
- name: Get GitHub App User ID
id: get-user-id
run: echo "user-id=$(gh api "/users/${{ steps.app-token.outputs.app-slug }}[bot]" --jq .id)" >> "$GITHUB_OUTPUT"
env:
GH_TOKEN: ${{ steps.app-token.outputs.token }}
- name: Checkout code
uses: actions/checkout@v4
with:
token: ${{ steps.app-token.outputs.token }}
fetch-depth: 0
- name: Check branch setup
run: |
# Check branch setup
if [[ "$GITHUB_REF" != "refs/heads/master" ]] && [[ "$GITHUB_REF" != "refs/heads/master-"* ]]
then
echo "This workflow must be run on master or master-* branch, not $GITHUB_REF"
exit 1
fi
- name: Tag and bump version
if: needs.get-version.outputs.is-snapshot == 'true'
run: |
# check for unreleased entry in CHANGELOG.md
readarray -t changes < <(grep -A 100 "^## \[UNRELEASED\] - YYYY-MM-DD" CHANGELOG.md | grep -B 100 --max-count=1 -E "^## \[[0-9.]+\]" | grep "^-")
if [ ${#changes[@]} -eq 0 ]
then
echo "Did not find any changes in CHANGELOG.md under '## [UNRELEASED] - YYYY-MM-DD'"
exit 1
fi
# get latest and release version
latest=$(grep --max-count=1 "<version>.*</version>" README.md | sed -E -e "s/\s*<[^>]+>//g" -e "s/-[0-9.]+//g")
version=$(grep --max-count=1 "<version>.*</version>" pom.xml | sed -E -e "s/\s*<[^>]+>//g" -e "s/-SNAPSHOT//" -e "s/-[0-9.]+//g")
# update changelog
echo "Releasing ${#changes[@]} changes as version $version:"
for (( i=0; i<${#changes[@]}; i++ )); do echo "${changes[$i]}" ; done
sed -i "s/## \[UNRELEASED\] - YYYY-MM-DD/## [$version] - $(date +%Y-%m-%d)/" CHANGELOG.md
sed -i -e "s/$latest-/$version-/g" -e "s/$latest\./$version./g" README.md PYSPARK-DEPS.md python/README.md
./set-version.sh $version
# configure git so we can commit changes
git config --global user.name '${{ steps.app-token.outputs.app-slug }}[bot]'
git config --global user.email '${{ steps.get-user-id.outputs.user-id }}+${{ steps.app-token.outputs.app-slug }}[bot]@users.noreply.github.com'
# commit changes to local repo
echo "Committing release to local git"
git add pom.xml python/setup.py CHANGELOG.md README.md PYSPARK-DEPS.md python/README.md
git commit -m "Releasing $version"
git tag -a "v${version}" -m "Release v${version}"
# bump version
# define function to bump version
function next_version {
local version=$1
local branch=$2
patch=${version/*./}
majmin=${version%.${patch}}
if [[ $branch == "master" ]]
then
# minor version bump
if [[ $version != *".0" ]]
then
echo "version is patch version, should be M.m.0: $version" >&2
exit 1
fi
maj=${version/.*/}
min=${majmin#${maj}.}
next=${maj}.$((min+1)).0
echo "$next"
else
# patch version bump
next=${majmin}.$((patch+1))
echo "$next"
fi
}
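# e.g. next_version "2.13.0" "master"     -> 2.14.0 (minor bump on master)
#      next_version "2.13.3" "master-3.5" -> 2.13.4 (patch bump elsewhere)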
# get next version
pkg_version="${version/-*/}"
branch=$(git rev-parse --abbrev-ref HEAD)
next_pkg_version="$(next_version "$pkg_version" "$branch")"
# bump the version
echo "Bump version to $next_pkg_version"
./set-version.sh $next_pkg_version-SNAPSHOT
# commit changes to local repo
echo "Committing release to local git"
git commit -a -m "Post-release version bump to $next_pkg_version"
# push all commits and tag to origin
echo "Pushing release commit and tag to origin"
git push origin "$GITHUB_REF_NAME" "v${version}" --tags
# NOTE: This push will not trigger a CI as we are using GITHUB_TOKEN to push
# More info on: https://docs.github.com/en/actions/using-workflows/triggering-a-workflow#triggering-a-workflow-from-a-workflow
github-release:
name: Create GitHub release
runs-on: ubuntu-latest
needs:
- get-version
- prepare-release
permissions:
contents: write # required to create release
steps:
- name: Checkout release tag
uses: actions/checkout@v4
with:
ref: ${{ needs.get-version.outputs.release-tag }}
- name: Extract release notes
id: release-notes
run: |
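# Print everything from the first "## " heading in CHANGELOG.md (the latest
# release) up to, but not including, the next "## " heading.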
awk '/^## /{if(seen==1)exit; seen++} seen' CHANGELOG.md > ./release-notes.txt
# Grab release name
name=$(grep -m 1 "^## " CHANGELOG.md | sed "s/^## //")
echo "release_name=$name" >> $GITHUB_OUTPUT
# provide release notes file path as output
echo "release_notes_path=release-notes.txt" >> $GITHUB_OUTPUT
- name: Publish GitHub release
uses: ncipollo/release-action@2c591bcc8ecdcd2db72b97d6147f871fcd833ba5
id: github-release
with:
name: ${{ steps.release-notes.outputs.release_name }}
bodyFile: ${{ steps.release-notes.outputs.release_notes_path }}
makeLatest: ${{ inputs.github_release_latest }}
tag: ${{ needs.get-version.outputs.release-tag }}
token: ${{ github.token }}
================================================
FILE: .github/workflows/prime-caches.yml
================================================
name: Prime caches
on:
workflow_dispatch:
jobs:
prime:
name: Spark ${{ matrix.spark-compat-version }}.${{ matrix.spark-patch-version }}${{ matrix.spark-snapshot-version && '-SNAPSHOT' }} Scala ${{ matrix.scala-version }}
runs-on: ubuntu-latest
strategy:
fail-fast: false
# keep in-sync with .github/workflows/test-jvm.yml
matrix:
include:
- spark-compat-version: '3.2'
scala-compat-version: '2.12'
scala-version: '2.12.15'
spark-patch-version: '4'
hadoop-version: '2.7'
- spark-compat-version: '3.3'
scala-compat-version: '2.12'
scala-version: '2.12.15'
spark-patch-version: '4'
hadoop-version: '3'
- spark-compat-version: '3.4'
scala-compat-version: '2.12'
scala-version: '2.12.17'
spark-patch-version: '4'
hadoop-version: '3'
- spark-compat-version: '3.5'
scala-compat-version: '2.12'
scala-version: '2.12.18'
spark-patch-version: '8'
hadoop-version: '3'
- spark-compat-version: '3.2'
scala-compat-version: '2.13'
scala-version: '2.13.5'
spark-patch-version: '4'
hadoop-version: '3.2'
- spark-compat-version: '3.3'
scala-compat-version: '2.13'
scala-version: '2.13.8'
spark-patch-version: '4'
hadoop-version: '3'
- spark-compat-version: '3.4'
scala-compat-version: '2.13'
scala-version: '2.13.8'
spark-patch-version: '4'
hadoop-version: '3'
- spark-compat-version: '3.5'
scala-compat-version: '2.13'
scala-version: '2.13.8'
spark-patch-version: '8'
hadoop-version: '3'
- spark-compat-version: '4.0'
scala-compat-version: '2.13'
scala-version: '2.13.16'
spark-patch-version: '2'
java-compat-version: '17'
hadoop-version: '3'
- spark-compat-version: '4.1'
scala-compat-version: '2.13'
scala-version: '2.13.17'
spark-patch-version: '1'
java-compat-version: '17'
hadoop-version: '3'
- spark-compat-version: '4.2'
scala-compat-version: '2.13'
scala-version: '2.13.18'
spark-patch-version: '0-preview3'
java-compat-version: '17'
hadoop-version: '3'
- spark-compat-version: '3.2'
scala-compat-version: '2.12'
scala-version: '2.12.15'
spark-patch-version: '5'
spark-snapshot-version: true
hadoop-version: '2.7'
- spark-compat-version: '3.3'
scala-compat-version: '2.12'
scala-version: '2.12.15'
spark-patch-version: '5'
spark-snapshot-version: true
hadoop-version: '3'
- spark-compat-version: '3.4'
scala-compat-version: '2.12'
scala-version: '2.12.17'
spark-patch-version: '5'
spark-snapshot-version: true
hadoop-version: '3'
- spark-compat-version: '3.5'
scala-compat-version: '2.12'
scala-version: '2.12.18'
spark-patch-version: '9'
spark-snapshot-version: true
hadoop-version: '3'
- spark-compat-version: '3.2'
scala-compat-version: '2.13'
scala-version: '2.13.5'
spark-patch-version: '5'
spark-snapshot-version: true
hadoop-version: '3.2'
- spark-compat-version: '3.3'
scala-compat-version: '2.13'
scala-version: '2.13.8'
spark-patch-version: '5'
spark-snapshot-version: true
hadoop-version: '3'
- spark-compat-version: '3.4'
scala-compat-version: '2.13'
scala-version: '2.13.8'
spark-patch-version: '5'
spark-snapshot-version: true
hadoop-version: '3'
- spark-compat-version: '3.5'
scala-compat-version: '2.13'
scala-version: '2.13.8'
spark-patch-version: '9'
spark-snapshot-version: true
hadoop-version: '3'
- spark-compat-version: '4.0'
scala-compat-version: '2.13'
scala-version: '2.13.16'
spark-patch-version: '3'
spark-snapshot-version: true
hadoop-version: '3'
- spark-compat-version: '4.1'
scala-compat-version: '2.13'
scala-version: '2.13.17'
spark-patch-version: '2'
spark-snapshot-version: true
hadoop-version: '3'
- spark-compat-version: '4.2'
scala-compat-version: '2.13'
scala-version: '2.13.18'
spark-patch-version: '0'
spark-snapshot-version: true
hadoop-version: '3'
steps:
- name: Checkout
uses: actions/checkout@v4
- name: Prime caches
uses: ./.github/actions/prime-caches
with:
spark-version: ${{ matrix.spark-compat-version }}.${{ matrix.spark-patch-version }}${{ matrix.spark-snapshot-version && '-SNAPSHOT' }}
scala-version: ${{ matrix.scala-version }}
spark-compat-version: ${{ matrix.spark-compat-version }}
scala-compat-version: ${{ matrix.scala-compat-version }}
hadoop-version: ${{ matrix.hadoop-version }}
java-compat-version: ${{ matrix.java-compat-version || '8' }}
================================================
FILE: .github/workflows/publish-release.yml
================================================
name: Publish release
on:
workflow_dispatch:
inputs:
versions:
required: true
type: string
description: 'Example: {"include": [{"params": {"spark-version": "4.0.0","scala-version": "2.13.16"}}]}'
default: |
{
"include": [
{"params": {"spark-version": "3.2.4", "scala-version": "2.12.15", "java-compat-version": "8"}},
{"params": {"spark-version": "3.3.4", "scala-version": "2.12.15", "java-compat-version": "8"}},
{"params": {"spark-version": "3.4.4", "scala-version": "2.12.17", "java-compat-version": "8"}},
{"params": {"spark-version": "3.5.8", "scala-version": "2.12.18", "java-compat-version": "8"}},
{"params": {"spark-version": "3.2.4", "scala-version": "2.13.5", "java-compat-version": "8"}},
{"params": {"spark-version": "3.3.4", "scala-version": "2.13.8", "java-compat-version": "8"}},
{"params": {"spark-version": "3.4.4", "scala-version": "2.13.8", "java-compat-version": "8"}},
{"params": {"spark-version": "3.5.8", "scala-version": "2.13.8", "java-compat-version": "8"}},
{"params": {"spark-version": "4.0.2", "scala-version": "2.13.16", "java-compat-version": "17"}},
{"params": {"spark-version": "4.1.1", "scala-version": "2.13.17", "java-compat-version": "17"}}
]
}
env:
# PySpark 3 versions only work with Python 3.9
PYTHON_VERSION: "3.9"
jobs:
get-version:
name: Get version
runs-on: ubuntu-latest
outputs:
release-tag: ${{ steps.versions.outputs.release-tag }}
is-snapshot: ${{ steps.versions.outputs.is-snapshot }}
steps:
- name: Checkout release tag
uses: actions/checkout@v4
- name: Get versions
id: versions
run: |
# get release version
version=$(grep --max-count=1 "<version>.*</version>" pom.xml | sed -E -e "s/\s*<[^>]+>//g" -e "s/-SNAPSHOT//" -e "s/-[0-9.]+//g")
is_snapshot=$(if grep -q "<version>.*-SNAPSHOT</version>" pom.xml; then echo "true"; else echo "false"; fi)
# share versions
echo "release-tag=v${version}" >> "$GITHUB_OUTPUT"
echo "is-snapshot=$is_snapshot" >> "$GITHUB_OUTPUT"
- name: Check tag setup
run: |
# Check tag setup
if [[ "$GITHUB_REF" != "refs/tags/v"* ]]
then
echo "This workflow must be run on a tag, not $GITHUB_REF"
exit 1
fi
if [ "${{ steps.versions.outputs.is-snapshot }}" == "true" ]
then
echo "This is a tagged SNAPSHOT version. This is not allowed for release!"
exit 1
fi
if [ "${{ github.ref_name }}" != "${{ steps.versions.outputs.release-tag }}" ]
then
echo "The version in the pom.xml is ${{ steps.versions.outputs.release-tag }}"
echo "This tag is ${{ github.ref_name }}, which is different!"
exit 1
fi
- name: Show matrix
run: |
echo '${{ github.event.inputs.versions }}' | jq .
maven-release:
name: Publish maven release (Spark ${{ matrix.params.spark-version }}, Scala ${{ matrix.params.scala-version }})
runs-on: ubuntu-latest
needs: get-version
if: ( ! github.event.repository.fork )
# secrets are provided by environment
environment:
name: release
# a different URL for each point in the matrix, but the same URLs across commits
url: 'https://github.com/G-Research/spark-extension?version=${{ needs.get-version.outputs.release-tag }}&spark=${{ matrix.params.spark-version }}&scala=${{ matrix.params.scala-version }}&package=maven'
permissions: {}
strategy:
fail-fast: false
matrix: ${{ fromJson(github.event.inputs.versions) }}
steps:
- name: Checkout release tag
uses: actions/checkout@v4
- name: Set up JDK and publish to Maven Central
uses: actions/setup-java@3a4f6e1af504cf6a31855fa899c6aa5355ba6c12 # v4.7.0
with:
java-version: ${{ matrix.params.java-compat-version }}
distribution: 'corretto'
server-id: central
server-username: MAVEN_USERNAME
server-password: MAVEN_PASSWORD
gpg-private-key: ${{ secrets.MAVEN_GPG_PRIVATE_KEY }}
gpg-passphrase: MAVEN_GPG_PASSPHRASE
- name: Inspect GPG
run: gpg -k
- name: Restore Maven packages cache
id: cache-maven
uses: actions/cache/restore@v4
with:
path: ~/.m2/repository
key: ${{ runner.os }}-mvn-build-${{ matrix.params.spark-version }}-${{ matrix.params.scala-version }}-${{ hashFiles('pom.xml') }}
restore-keys: |
${{ runner.os }}-mvn-build-${{ matrix.params.spark-version }}-${{ matrix.params.scala-version }}-${{ hashFiles('pom.xml') }}
${{ runner.os }}-mvn-build-${{ matrix.params.spark-version }}-${{ matrix.params.scala-version }}-
- name: Publish maven artifacts
id: publish-maven
run: |
./set-version.sh ${{ matrix.params.spark-version }} ${{ matrix.params.scala-version }}
mvn clean deploy -Dsign -Dspotless.check.skip -DskipTests -Dmaven.test.skip=true
env:
MAVEN_USERNAME: ${{ secrets.MAVEN_USERNAME }}
MAVEN_PASSWORD: ${{ secrets.MAVEN_PASSWORD }}
MAVEN_GPG_PASSPHRASE: ${{ secrets.MAVEN_GPG_PASSPHRASE}}
pypi-release:
name: Publish PyPi release (Spark ${{ matrix.params.spark-version }}, Scala ${{ matrix.params.scala-version }})
runs-on: ubuntu-latest
needs: get-version
if: ( ! github.event.repository.fork )
# secrets are provided by environment
environment:
name: release
# a different URL for each point in the matrix, but the same URLs across commits
url: 'https://github.com/G-Research/spark-extension?version=${{ needs.get-version.outputs.release-tag }}&spark=${{ matrix.params.spark-version }}&scala=${{ matrix.params.scala-version }}&package=pypi'
permissions:
id-token: write # required for PyPI publish
strategy:
fail-fast: false
matrix: ${{ fromJson(github.event.inputs.versions) }}
steps:
- name: Checkout release tag
uses: actions/checkout@v4
- name: Set up JDK
uses: actions/setup-java@3a4f6e1af504cf6a31855fa899c6aa5355ba6c12 # v4.7.0
with:
java-version: ${{ matrix.params.java-compat-version }}
distribution: 'corretto'
- uses: actions/setup-python@v5
with:
python-version: ${{ env.PYTHON_VERSION }}
- name: Restore Maven packages cache
id: cache-maven
uses: actions/cache/restore@v4
with:
path: ~/.m2/repository
key: ${{ runner.os }}-mvn-build-${{ matrix.params.spark-version }}-${{ matrix.params.scala-version }}-${{ hashFiles('pom.xml') }}
restore-keys: |
${{ runner.os }}-mvn-build-${{ matrix.params.spark-version }}-${{ matrix.params.scala-version }}-${{ hashFiles('pom.xml') }}
${{ runner.os }}-mvn-build-${{ matrix.params.spark-version }}-${{ matrix.params.scala-version }}-
- name: Build maven artifacts
id: maven
if: startsWith(matrix.params.spark-version, '3.') && startsWith(matrix.params.scala-version, '2.12.') || startsWith(matrix.params.spark-version, '4.') && startsWith(matrix.params.scala-version, '2.13.')
run: |
./set-version.sh ${{ matrix.params.spark-version }} ${{ matrix.params.scala-version }}
mvn clean package -Dspotless.check.skip -DskipTests -Dmaven.test.skip=true
- name: Prepare PyPi package
id: prepare-pypi-package
if: steps.maven.outcome == 'success'
run: |
./build-whl.sh
- name: Publish package distributions to PyPI
uses: pypa/gh-action-pypi-publish@release/v1
if: steps.prepare-pypi-package.outcome == 'success'
with:
packages-dir: python/dist
skip-existing: true
verbose: true
================================================
FILE: .github/workflows/publish-snapshot.yml
================================================
name: Publish snapshot
on:
workflow_dispatch:
push:
branches: ["master"]
env:
PYTHON_VERSION: "3.10"
jobs:
check-version:
name: Check SNAPSHOT version
if: ( ! github.event.repository.fork )
runs-on: ubuntu-latest
permissions: {}
outputs:
is-snapshot: ${{ steps.check.outputs.is-snapshot }}
steps:
- name: Checkout code
uses: actions/checkout@v4
- name: Check if this is a SNAPSHOT version
id: check
run: |
# check whether this is a SNAPSHOT version
if grep -q "<version>.*-SNAPSHOT</version>" pom.xml
then
echo "Version in pom IS a SNAPSHOT version"
echo "is-snapshot=true" >> "$GITHUB_OUTPUT"
else
echo "Version in pom is NOT a SNAPSHOT version"
echo "is-snapshot=false" >> "$GITHUB_OUTPUT"
fi
snapshot:
name: Snapshot Spark ${{ matrix.params.spark-version }} Scala ${{ matrix.params.scala-version }}
needs: check-version
# when we release from master, this workflow will see a commit that does not have a SNAPSHOT version
# we want this workflow to skip over that commit
if: needs.check-version.outputs.is-snapshot == 'true'
runs-on: ubuntu-latest
# secrets are provided by environment
environment:
name: snapshot
# a different URL for each point in the matrix, but the same URLs across commits
url: 'https://github.com/G-Research/spark-extension?spark=${{ matrix.params.spark-version }}&scala=${{ matrix.params.scala-version }}&snapshot'
permissions: {}
strategy:
fail-fast: false
matrix:
include:
- params: {"spark-version": "3.2.4", "scala-version": "2.12.15", "scala-compat-version": "2.12", "java-compat-version": "8"}
- params: {"spark-version": "3.3.4", "scala-version": "2.12.15", "scala-compat-version": "2.12", "java-compat-version": "8"}
- params: {"spark-version": "3.4.4", "scala-version": "2.12.17", "scala-compat-version": "2.12", "java-compat-version": "8"}
- params: {"spark-version": "3.5.8", "scala-version": "2.12.18", "scala-compat-version": "2.12", "java-compat-version": "8"}
- params: {"spark-version": "3.2.4", "scala-version": "2.13.5", "scala-compat-version": "2.13", "java-compat-version": "8"}
- params: {"spark-version": "3.3.4", "scala-version": "2.13.8", "scala-compat-version": "2.13", "java-compat-version": "8"}
- params: {"spark-version": "3.4.4", "scala-version": "2.13.8", "scala-compat-version": "2.13", "java-compat-version": "8"}
- params: {"spark-version": "3.5.8", "scala-version": "2.13.8", "scala-compat-version": "2.13", "java-compat-version": "8"}
- params: {"spark-version": "4.0.2", "scala-version": "2.13.16", "scala-compat-version": "2.13", "java-compat-version": "17"}
- params: {"spark-version": "4.1.1", "scala-version": "2.13.17", "scala-compat-version": "2.13", "java-compat-version": "17"}
steps:
- name: Checkout code
uses: actions/checkout@v4
- name: Set up JDK and publish to Maven Central
uses: actions/setup-java@3a4f6e1af504cf6a31855fa899c6aa5355ba6c12 # v4.7.0
with:
java-version: ${{ matrix.params.java-compat-version }}
distribution: 'corretto'
server-id: central
server-username: MAVEN_USERNAME
server-password: MAVEN_PASSWORD
gpg-private-key: ${{ secrets.MAVEN_GPG_PRIVATE_KEY }}
gpg-passphrase: MAVEN_GPG_PASSPHRASE
- name: Inspect GPG
run: gpg -k
- uses: actions/setup-python@v5
with:
python-version: ${{ env.PYTHON_VERSION }}
- name: Restore Maven packages cache
id: cache-maven
uses: actions/cache/restore@v4
with:
path: ~/.m2/repository
key: ${{ runner.os }}-mvn-build-${{ matrix.params.spark-version }}-${{ matrix.params.scala-version }}-${{ hashFiles('pom.xml') }}
restore-keys: |
${{ runner.os }}-mvn-build-${{ matrix.params.spark-version }}-${{ matrix.params.scala-version }}-${{ hashFiles('pom.xml') }}
${{ runner.os }}-mvn-build-${{ matrix.params.spark-version }}-${{ matrix.params.scala-version }}-
- name: Publish snapshot
run: |
./set-version.sh ${{ matrix.params.spark-version }} ${{ matrix.params.scala-version }}
mvn clean deploy -Dsign -Dspotless.check.skip -DskipTests -Dmaven.test.skip=true
env:
MAVEN_USERNAME: ${{ secrets.MAVEN_USERNAME }}
MAVEN_PASSWORD: ${{ secrets.MAVEN_PASSWORD }}
MAVEN_GPG_PASSPHRASE: ${{ secrets.MAVEN_GPG_PASSPHRASE}}
- name: Prepare PyPi package to test snapshot
if: startsWith(matrix.params.scala-version, '2.12.')
run: |
# Build whl
./build-whl.sh
- name: Restore Spark Binaries cache
uses: actions/cache/restore@v4
with:
path: ~/spark
key: ${{ runner.os }}-spark-binaries-${{ matrix.params.spark-version }}-${{ matrix.params.scala-compat-version }}
restore-keys: |
${{ runner.os }}-spark-binaries-${{ matrix.params.spark-version }}-${{ matrix.params.scala-compat-version }}
- name: Rename Spark Binaries cache
run: |
mv ~/spark ./spark-${{ matrix.params.spark-version }}-${{ matrix.params.scala-compat-version }}
- name: Test snapshot
id: test-package
run: |
# Test the snapshot (needs whl)
./test-release.sh
================================================
FILE: .github/workflows/test-jvm.yml
================================================
name: Test JVM
on:
workflow_call:
jobs:
test:
name: Test (Spark ${{ matrix.spark-compat-version }}.${{ matrix.spark-patch-version }} Scala ${{ matrix.scala-version }})
runs-on: ubuntu-latest
strategy:
fail-fast: false
# keep in-sync with .github/workflows/prime-caches.yml
matrix:
include:
- spark-compat-version: '3.2'
scala-compat-version: '2.12'
scala-version: '2.12.15'
spark-patch-version: '4'
java-compat-version: '8'
hadoop-version: '2.7'
- spark-compat-version: '3.3'
scala-compat-version: '2.12'
scala-version: '2.12.15'
spark-patch-version: '4'
java-compat-version: '8'
hadoop-version: '3'
- spark-compat-version: '3.4'
scala-compat-version: '2.12'
scala-version: '2.12.17'
spark-patch-version: '4'
java-compat-version: '8'
hadoop-version: '3'
- spark-compat-version: '3.5'
scala-compat-version: '2.12'
scala-version: '2.12.18'
spark-patch-version: '7'
java-compat-version: '8'
hadoop-version: '3'
- spark-compat-version: '3.2'
scala-compat-version: '2.13'
scala-version: '2.13.5'
spark-patch-version: '4'
java-compat-version: '8'
hadoop-version: '3.2'
- spark-compat-version: '3.3'
scala-compat-version: '2.13'
scala-version: '2.13.8'
spark-patch-version: '4'
java-compat-version: '8'
hadoop-version: '3'
- spark-compat-version: '3.4'
scala-compat-version: '2.13'
scala-version: '2.13.8'
spark-patch-version: '4'
java-compat-version: '8'
hadoop-version: '3'
- spark-compat-version: '3.5'
scala-compat-version: '2.13'
scala-version: '2.13.8'
spark-patch-version: '7'
java-compat-version: '8'
hadoop-version: '3'
- spark-compat-version: '4.0'
scala-compat-version: '2.13'
scala-version: '2.13.16'
spark-patch-version: '2'
java-compat-version: '17'
hadoop-version: '3'
- spark-compat-version: '4.1'
scala-compat-version: '2.13'
scala-version: '2.13.17'
spark-patch-version: '1'
java-compat-version: '17'
hadoop-version: '3'
- spark-compat-version: '4.2'
scala-compat-version: '2.13'
scala-version: '2.13.18'
spark-patch-version: '0-preview3'
java-compat-version: '17'
hadoop-version: '3'
steps:
- name: Checkout
uses: actions/checkout@v4
- name: Test
uses: ./.github/actions/test-jvm
env:
CI_SLOW_TESTS: 1
with:
spark-version: ${{ matrix.spark-compat-version }}.${{ matrix.spark-patch-version }}
scala-version: ${{ matrix.scala-version }}
spark-compat-version: ${{ matrix.spark-compat-version }}
spark-archive-url: ${{ matrix.spark-archive-url }}
scala-compat-version: ${{ matrix.scala-compat-version }}
java-compat-version: ${{ matrix.java-compat-version }}
hadoop-version: ${{ matrix.hadoop-version }}
================================================
FILE: .github/workflows/test-python.yml
================================================
name: Test Python
on:
workflow_call:
jobs:
# pyspark is not available for snapshots or scala other than 2.12
# we would have to compile spark from sources for this, not worth it
test:
name: Test (Spark ${{ matrix.spark-version }} Scala ${{ matrix.scala-version }} Python ${{ matrix.python-version }})
runs-on: ubuntu-latest
strategy:
fail-fast: false
matrix:
spark-compat-version: ['3.2', '3.3', '3.4', '3.5', '4.0']
python-version: ['3.9', '3.10', '3.11', '3.12', '3.13']
include:
- spark-compat-version: '3.2'
spark-version: '3.2.4'
scala-compat-version: '2.12'
scala-version: '2.12.15'
java-compat-version: '8'
hadoop-version: '2.7'
- spark-compat-version: '3.3'
spark-version: '3.3.4'
scala-compat-version: '2.12'
scala-version: '2.12.15'
java-compat-version: '8'
hadoop-version: '3'
- spark-compat-version: '3.4'
spark-version: '3.4.4'
scala-compat-version: '2.12'
scala-version: '2.12.17'
java-compat-version: '8'
hadoop-version: '3'
- spark-compat-version: '3.5'
spark-version: '3.5.8'
scala-compat-version: '2.12'
scala-version: '2.12.18'
java-compat-version: '8'
hadoop-version: '3'
- spark-compat-version: '4.0'
spark-version: '4.0.2'
scala-compat-version: '2.13'
scala-version: '2.13.16'
java-compat-version: '17'
hadoop-version: '3'
- spark-compat-version: '4.1'
spark-version: '4.1.1'
scala-compat-version: '2.13'
scala-version: '2.13.17'
java-compat-version: '17'
hadoop-version: '3'
python-version: '3.10'
- spark-compat-version: '4.2'
spark-version: '4.2.0-preview3'
scala-compat-version: '2.13'
scala-version: '2.13.18'
java-compat-version: '17'
hadoop-version: '3'
python-version: '3.10'
exclude:
- spark-compat-version: '3.2'
python-version: '3.10'
- spark-compat-version: '3.2'
python-version: '3.11'
- spark-compat-version: '3.2'
python-version: '3.12'
- spark-compat-version: '3.2'
python-version: '3.13'
- spark-compat-version: '3.3'
python-version: '3.11'
- spark-compat-version: '3.3'
python-version: '3.12'
- spark-compat-version: '3.3'
python-version: '3.13'
- spark-compat-version: '3.4'
python-version: '3.12'
- spark-compat-version: '3.4'
python-version: '3.13'
- spark-compat-version: '3.5'
python-version: '3.12'
- spark-compat-version: '3.5'
python-version: '3.13'
steps:
- name: Checkout
uses: actions/checkout@v4
- name: Test
uses: ./.github/actions/test-python
with:
spark-version: ${{ matrix.spark-version }}
scala-version: ${{ matrix.scala-version }}
spark-compat-version: ${{ matrix.spark-compat-version }}
spark-archive-url: ${{ matrix.spark-archive-url }}
spark-package-repo: ${{ matrix.spark-package-repo }}
scala-compat-version: ${{ matrix.scala-compat-version }}
java-compat-version: ${{ matrix.java-compat-version }}
hadoop-version: ${{ matrix.hadoop-version }}
python-version: ${{ matrix.python-version }}
================================================
FILE: .github/workflows/test-release.yml
================================================
name: Test release
on:
workflow_call:
jobs:
test:
name: Test Release Spark ${{ matrix.spark-version }} Scala ${{ matrix.scala-version }}
runs-on: ubuntu-latest
strategy:
fail-fast: false
matrix:
include:
- spark-compat-version: '3.2'
spark-version: '3.2.4'
scala-compat-version: '2.12'
scala-version: '2.12.15'
java-compat-version: '8'
hadoop-version: '2.7'
python-version: '3.9'
- spark-compat-version: '3.3'
spark-version: '3.3.4'
scala-compat-version: '2.12'
scala-version: '2.12.15'
java-compat-version: '8'
hadoop-version: '3'
python-version: '3.10'
- spark-compat-version: '3.4'
spark-version: '3.4.4'
scala-compat-version: '2.12'
scala-version: '2.12.17'
java-compat-version: '8'
hadoop-version: '3'
python-version: '3.11'
- spark-compat-version: '3.5'
spark-version: '3.5.8'
scala-compat-version: '2.12'
scala-version: '2.12.18'
java-compat-version: '8'
hadoop-version: '3'
python-version: '3.11'
- spark-compat-version: '3.2'
spark-version: '3.2.4'
scala-compat-version: '2.13'
scala-version: '2.13.5'
java-compat-version: '8'
hadoop-version: '3.2'
- spark-compat-version: '3.3'
spark-version: '3.3.4'
scala-compat-version: '2.13'
scala-version: '2.13.8'
java-compat-version: '8'
hadoop-version: '3'
- spark-compat-version: '3.4'
spark-version: '3.4.4'
scala-compat-version: '2.13'
scala-version: '2.13.8'
java-compat-version: '8'
hadoop-version: '3'
- spark-compat-version: '3.5'
spark-version: '3.5.8'
scala-compat-version: '2.13'
scala-version: '2.13.8'
java-compat-version: '8'
hadoop-version: '3'
- spark-compat-version: '4.0'
spark-version: '4.0.2'
scala-compat-version: '2.13'
scala-version: '2.13.16'
java-compat-version: '17'
hadoop-version: '3'
python-version: '3.13'
- spark-compat-version: '4.1'
spark-version: '4.1.1'
scala-compat-version: '2.13'
scala-version: '2.13.17'
java-compat-version: '17'
hadoop-version: '3'
python-version: '3.13'
- spark-compat-version: '4.2'
spark-version: '4.2.0-preview3'
scala-compat-version: '2.13'
scala-version: '2.13.18'
java-compat-version: '17'
hadoop-version: '3'
python-version: '3.13'
steps:
- name: Checkout
uses: actions/checkout@v4
- name: Test
uses: ./.github/actions/test-release
with:
spark-version: ${{ matrix.spark-version }}
scala-version: ${{ matrix.scala-version }}
spark-compat-version: ${{ matrix.spark-compat-version }}
spark-archive-url: ${{ matrix.spark-archive-url }}
scala-compat-version: ${{ matrix.scala-compat-version }}
java-compat-version: ${{ matrix.java-compat-version }}
hadoop-version: ${{ matrix.hadoop-version }}
python-version: ${{ matrix.python-version }}
================================================
FILE: .github/workflows/test-results.yml
================================================
name: Test Results
on:
workflow_run:
workflows: ["CI"]
types:
- completed
permissions: {}
jobs:
publish-test-results:
name: Publish Test Results
runs-on: ubuntu-latest
if: github.event.workflow_run.conclusion != 'skipped'
permissions:
checks: write
pull-requests: write
steps:
- name: Download and Extract Artifacts
uses: dawidd6/action-download-artifact@09f2f74827fd3a8607589e5ad7f9398816f540fe
with:
run_id: ${{ github.event.workflow_run.id }}
name: "^Event File$| Test Results "
name_is_regexp: true
path: artifacts
- name: Publish Test Results
uses: EnricoMi/publish-unit-test-result-action@v2
with:
commit: ${{ github.event.workflow_run.head_sha }}
event_file: artifacts/Event File/event.json
event_name: ${{ github.event.workflow_run.event }}
files: "artifacts/* Test Results*/**/*.xml"
================================================
FILE: .github/workflows/test-snapshots.yml
================================================
name: Test Snapshots
on:
workflow_call:
jobs:
test:
name: Test (Spark ${{ matrix.spark-version }} Scala ${{ matrix.scala-version }})
runs-on: ubuntu-latest
strategy:
fail-fast: false
matrix:
include:
- spark-compat-version: '3.2'
spark-version: '3.2.5-SNAPSHOT'
scala-compat-version: '2.12'
scala-version: '2.12.15'
java-compat-version: '8'
- spark-compat-version: '3.3'
spark-version: '3.3.5-SNAPSHOT'
scala-compat-version: '2.12'
scala-version: '2.12.15'
java-compat-version: '8'
- spark-compat-version: '3.4'
spark-version: '3.4.5-SNAPSHOT'
scala-compat-version: '2.12'
scala-version: '2.12.17'
java-compat-version: '8'
- spark-compat-version: '3.5'
spark-version: '3.5.9-SNAPSHOT'
scala-compat-version: '2.12'
scala-version: '2.12.18'
java-compat-version: '8'
- spark-compat-version: '3.2'
spark-version: '3.2.5-SNAPSHOT'
scala-compat-version: '2.13'
scala-version: '2.13.5'
java-compat-version: '8'
- spark-compat-version: '3.3'
spark-version: '3.3.5-SNAPSHOT'
scala-compat-version: '2.13'
scala-version: '2.13.8'
java-compat-version: '8'
- spark-compat-version: '3.4'
spark-version: '3.4.5-SNAPSHOT'
scala-compat-version: '2.13'
scala-version: '2.13.8'
java-compat-version: '8'
- spark-compat-version: '3.5'
spark-version: '3.5.9-SNAPSHOT'
scala-compat-version: '2.13'
scala-version: '2.13.8'
java-compat-version: '8'
- spark-compat-version: '4.0'
spark-version: '4.0.3-SNAPSHOT'
scala-compat-version: '2.13'
scala-version: '2.13.16'
java-compat-version: '17'
- spark-compat-version: '4.1'
spark-version: '4.1.2-SNAPSHOT'
scala-compat-version: '2.13'
scala-version: '2.13.17'
java-compat-version: '17'
- spark-compat-version: '4.2'
spark-version: '4.2.0-SNAPSHOT'
scala-compat-version: '2.13'
scala-version: '2.13.18'
java-compat-version: '17'
steps:
- name: Checkout
uses: actions/checkout@v4
- name: Test
uses: ./.github/actions/test-jvm
env:
CI_SLOW_TESTS: 1
with:
spark-version: ${{ matrix.spark-version }}
scala-version: ${{ matrix.scala-version }}
spark-compat-version: ${{ matrix.spark-compat-version }}-SNAPSHOT
scala-compat-version: ${{ matrix.scala-compat-version }}
java-compat-version: ${{ matrix.java-compat-version }}
================================================
FILE: .gitignore
================================================
# use glob syntax.
syntax: glob
*.ser
*.class
*~
*.bak
#*.off
*.old
# eclipse conf file
.settings
.classpath
.project
.manager
.scala_dependencies
# idea
.idea
*.iml
# building
target
build
null
tmp*
temp*
dist
test-output
build.log
# other scm
.svn
.CVS
.hg*
# switch to regexp syntax.
# syntax: regexp
# ^\.pc/
# project specific
python/**/__pycache__
spark-*
.cache
================================================
FILE: .scalafmt.conf
================================================
version = 3.7.17
runner.dialect = scala213
rewrite.trailingCommas.style = keep
docstrings.style = Asterisk
maxColumn = 120
================================================
FILE: CHANGELOG.md
================================================
# Changelog
All notable changes to this project will be documented in this file.
The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/).
## [2.15.0] - 2025-12-13
### Added
- Support encrypted parquet files (#324)
### Changed
- Remove support for Spark 3.0 and Spark 3.1 (#332)
- Make all undocumented unintended public API parts private (#331)
- Reading Parquet metadata can use a Parquet Hadoop version different from the one shipped with Spark (#330)
## [2.14.2] - 2025-07-21
### Changed
- Fixed release process (#320)
## [2.14.1] - 2025-07-17
### Changed
- Fixed release process (#319)
## [2.14.0] - 2025-07-17
### Added
- Support for Spark 4.0 (#269, #272, #293)
### Changed
- Improve backticks (#265)
New: This escapes backticks that already exist in column names.
Change: This does not quote columns that only contain letters, numbers
and underscores, which were quoted before.
- Move Python dependencies into `setup.py`, build jar from `setup.py` (#301)
## [2.13.0] - 2024-11-04
### Fixed
- Support diff for Spark Connect implemented via PySpark Dataset API (#251)
### Added
- Add ignore columns to diff in Python API (#252)
- Check that the Java / Scala package is installed when needed by Python (#250)
## [2.12.0] - 2024-04-26
### Fixed
- Diff change column should respect comparators (#238)
### Changed
- Make create_temporary_dir work with pyspark-extension only (#222).
This allows [installing PIP packages and Poetry projects](PYSPARK-DEPS.md)
via pure Python spark-extension package (Maven package not required any more).
- Add map diff comparator to Python API (#226)
## [2.11.0] - 2024-01-04
### Added
- Add count_null aggregate function (#206)
- Support reading parquet schema (#208)
- Add more columns to reading parquet metadata (#209, #211)
- Provide groupByKey shortcuts for groupBy.as (#213)
- Allow to install PIP packages into PySpark job (#215)
- Allow to install Poetry projects into PySpark job (#216)
## [2.10.0] - 2023-09-27
### Fixed
- Update setup.py to include parquet methods in python package (#191)
### Added
- Add --statistics option to diff app (#189)
- Add --filter option to diff app (#190)
## [2.9.0] - 2023-08-23
### Added
- Add key order sensitive map comparator (#187)
### Changed
- Use dataset encoder rather than implicit value encoder for implicit dataset extension class (#183)
### Fixed
- Fix key-sensitivity in map comparator (#186)
## [2.8.0] - 2023-05-24
### Added
- Add method to set and automatically unset Spark job description. (#172)
- Add column function that converts between .Net (C#, F#, Visual Basic) `DateTime.Ticks` and Spark timestamp / Unix epoch timestamps. (#153)
## [2.7.0] - 2023-05-05
### Added
- Spark app to diff files or tables and write result back to file or table. (#160)
- Add null value count to `parquetBlockColumns` and `parquet_block_columns`. (#162)
- Add `parallelism` argument to Parquet metadata methods. (#164)
### Changed
- Change data type of column name in `parquetBlockColumns` and `parquet_block_columns` to array of strings.
Cast to string to get earlier behaviour (string column name). (#162)
## [2.6.0] - 2023-04-11
### Added
- Add reader for parquet metadata. (#154)
## [2.5.0] - 2023-03-23
### Added
- Add whitespace agnostic diff comparator. (#137)
- Add Python whl package build. (#151)
## [2.4.0] - 2022-12-08
### Added
- Allow for custom diff equality. (#127)
### Fixed
- Fix Python API calling into Scala code. (#132)
## [2.3.0] - 2022-10-26
### Added
- Add diffWith to Scala, Java and Python Diff API. (#109)
### Changed
- Diff similar Datasets with ignoreColumns. Before, only similar DataFrames could be diffed with ignoreColumns. (#111)
### Fixed
- Cache before writing via partitionedBy to work around SPARK-40588. Unpersist via UnpersistHandle. (#124)
## [2.2.0] - 2022-07-21
### Added
- Add (global) row numbers transformation to Scala, Java and Python API. (#97)
### Removed
- Removed support for Python 3.6
## [2.1.0] - 2022-04-07
### Added
- Add sorted group methods to Dataset. (#76)
## [2.0.0] - 2021-10-29
### Added
- Add support for Spark 3.2 and Scala 2.13.
- Support to ignore columns in diff API. (#63)
### Removed
- Removed support for Spark 2.4.
## [1.3.3] - 2020-12-17
### Added
- Add support for Spark 3.1.
## [1.3.2] - 2020-12-17
### Changed
- Refine conditional transformation helper methods.
## [1.3.1] - 2020-12-10
### Changed
- Refine conditional transformation helper methods.
## [1.3.0] - 2020-12-07
### Added
- Add transformation to compute histogram. (#26)
- Add conditional transformation helper methods. (#27)
- Add partitioned writing helpers that simplify writing optimally ordered partitioned data. (#29)
## [1.2.0] - 2020-10-06
### Added
- Add diff modes (#22): column-by-column, side-by-side, left and right side diff modes.
- Add sparse mode (#23): diff DataFrame contains only changed values.
## [1.1.0] - 2020-08-24
### Added
- Add Python API for Diff transformation.
- Add change column to Diff transformation providing column names of all changed columns in a row.
- Add fluent methods to change immutable diff options.
- Add `backticks` method to handle column names that contain dots (`.`).
## [1.0.0] - 2020-03-12
### Added
- Add Diff transformation for Datasets.
================================================
FILE: CONDITIONAL.md
================================================
# DataFrame Transformations
The Spark `Dataset` API allows for chaining transformations as in the following example:
```scala
ds.where($"id" === 1)
.withColumn("state", lit("new"))
.orderBy($"timestamp")
```
When you define additional transformation functions, the `Dataset` API allows you to
also fluently call into those:
```scala
def transformation(df: DataFrame): DataFrame = df.distinct
ds.transform(transformation)
```
Here are some methods that extend this principle to conditional calls.
## Conditional Transformations
You can run a transformation only when a condition holds, without breaking the chain of fluent transformation calls:
```scala
import uk.co.gresearch._
val condition = true
val result =
ds.where($"id" === 1)
.withColumn("state", lit("new"))
.when(condition).call(transformation)
.orderBy($"timestamp")
```
rather than
```scala
val condition = true
val filteredDf = ds.where($"id" === 1)
.withColumn("state", lit("new"))
val condDf = if (condition) transformation(filteredDf) else filteredDf
val result = condDf.orderBy($"timestamp")
```
In case you need an else transformation as well, try:
```scala
import uk.co.gresearch._
val condition = true
val result =
ds.where($"id" === 1)
.withColumn("state", lit("new"))
.on(condition).either(transformation).or(other)
.orderBy($"timestamp")
```
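Here, `other` is any other transformation with the same signature as `transformation`. As a purely illustrative sketch (name and body are hypothetical, not part of this library):
```scala
// hypothetical else-branch transformation, for illustration only
def other(df: DataFrame): DataFrame = df.limit(100)
```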
## Fluent and conditional functions elsewhere
The same fluent notation works for instances other than `Dataset` or `DataFrame`, e.g.
for the `DataFrameWriter`:
```scala
def writeData[T](writer: DataFrameWriter[T]): Unit = { ... }
ds.write
.when(compress).call(_.option("compression", "gzip"))
.call(writeData)
```
================================================
FILE: DIFF.md
================================================
# Spark Diff
Add the following `import` to your Scala code:
```scala
import uk.co.gresearch.spark.diff._
```
or this `import` to your Python code:
```python
# noinspection PyUnresolvedReferences
from gresearch.spark.diff import *
```
This adds a `diff` transformation to `Dataset` and `DataFrame` that computes the differences between two datasets / dataframes,
i.e. which rows of one dataset / dataframe to _add_, _delete_ or _change_ to get to the other dataset / dataframe.
For example, in Scala
```scala
val left = Seq((1, "one"), (2, "two"), (3, "three")).toDF("id", "value")
val right = Seq((1, "one"), (2, "Two"), (4, "four")).toDF("id", "value")
```
or in Python:
```python
left = spark.createDataFrame([(1, "one"), (2, "two"), (3, "three")], ["id", "value"])
right = spark.createDataFrame([(1, "one"), (2, "Two"), (4, "four")], ["id", "value"])
```
diffing becomes as easy as:
```scala
left.diff(right).show()
```
|diff |id |value |
|:---:|:---:|:-----:|
| N| 1| one|
| D| 2| two|
| I| 2| Two|
| D| 3| three|
| I| 4| four|
With columns that provide unique identifiers per row (here `id`), the diff looks like:
```scala
left.diff(right, "id").show()
```
|diff |id |left_value|right_value|
|:---:|:---:|:--------:|:---------:|
| N| 1| one| one|
| C| 2| two| Two|
| D| 3| three| *null*|
| I| 4| *null*| four|
An equivalent alternative is this hand-crafted transformation (Scala):
```scala
left.withColumn("exists", lit(1)).as("l")
.join(right.withColumn("exists", lit(1)).as("r"),
$"l.id" <=> $"r.id",
"fullouter")
.withColumn("diff",
when($"l.exists".isNull, "I").
when($"r.exists".isNull, "D").
when(!($"l.value" <=> $"r.value"), "C").
otherwise("N"))
.show()
```
Statistics on the differences can be obtained by
```scala
left.diff(right, "id").groupBy("diff").count().show()
```
|diff |count |
|:----:|:-----:|
| N| 1|
| I| 1|
| D| 1|
| C| 1|
The `diff` transformation can optionally provide a *change column* that lists all non-id column names that have changed.
This column is an array of strings and only set for `"N"` and `"C"` action rows; it is *null* for `"I"` and `"D"` action rows.
|diff |changes|id |left_value|right_value|
|:---:|:-----:|:---:|:--------:|:---------:|
| N| []| 1| one| one|
| C|[value]| 2| two| Two|
| D| *null*| 3| three| *null*|
| I| *null*| 4| *null*| four|
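The change column is enabled by giving it a name via `DiffOptions` (a minimal sketch; see [Configuring Diff](#configuring-diff) below for all options):
```scala
// add a "changes" column that lists the names of all changed non-id columns
val options = DiffOptions.default.withChangeColumn("changes")
left.diff(right, options, "id").show()
```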
## Features
This `diff` transformation provides the following features:
* id columns are optional
* provides typed `diffAs` and `diffWith` transformations
* supports *null* values in id and non-id columns
* detects *null* value insertion / deletion
* [configurable](#configuring-diff) via `DiffOptions`:
* diff column name (default: `"diff"`), if default name exists in diff result schema
* diff action labels (defaults: `"N"`, `"I"`, `"D"`, `"C"`), allows custom diff notation, e.g. Unix diff left-right notation (<, >) or git before-after format (+, -, -+)
* [custom equality operators](#comparators-equality) (e.g. double comparison with epsilon threshold)
* [different diff result formats](#diffing-modes)
* [sparse diffing mode](#sparse-mode)
* optionally provides a *change column* that lists all non-id column names that have changed (only for `"C"` action rows)
* guarantees that no duplicate columns exist in the result, throws a readable exception otherwise
## Configuring Diff
Diffing can be configured via an optional `DiffOptions` instance (see [Methods](#methods) below).
|option |default |description|
|--------------------|:-------:|-----------|
|`diffColumn` |`"diff"` |The 'diff column' provides the action or diff value encoding if the respective row has been inserted, changed, deleted or has not been changed at all.|
|`leftColumnPrefix` |`"left"` |Non-id columns of the 'left' dataset are prefixed with this prefix.|
|`rightColumnPrefix` |`"right"`|Non-id columns of the 'right' dataset are prefixed with this prefix.|
|`insertDiffValue` |`"I"` |Inserted rows are marked with this string in the 'diff column'.|
|`changeDiffValue` |`"C"` |Changed rows are marked with this string in the 'diff column'.|
|`deleteDiffValue` |`"D"` |Deleted rows are marked with this string in the 'diff column'.|
|`nochangeDiffValue` |`"N"` |Unchanged rows are marked with this string in the 'diff column'.|
|`changeColumn` |*none* |An array with the names of all columns that have changed values is provided in this column (only for unchanged and changed rows, *null* otherwise).|
|`diffMode` |`DiffModes.Default`|Configures the diff output format. For details see [Diff Modes](#diff-modes) section below.|
|`sparseMode` |`false` |When `true`, only values that have changed are provided on left and right side, `null` is used for un-changed values.|
|`defaultComparator` |`DiffComparators.default()`|The default equality for all value columns.|
|`dataTypeComparators`|_empty_ |Map from data types to comparators.|
|`columnNameComparators`|_empty_|Map from column names to comparators.|
Either construct an instance via the constructor …
```scala
// Scala
import uk.co.gresearch.spark.diff.{DiffOptions, DiffMode}
val options = DiffOptions("d", "l", "r", "i", "c", "d", "n", Some("changes"), DiffMode.Default, false)
```
```python
# Python
from gresearch.spark.diff import DiffOptions, DiffMode
options = DiffOptions("d", "l", "r", "i", "c", "d", "n", "changes", DiffMode.Default, False)
```
… or via the `.with*` methods. The former requires most options to be specified, whereas the latter
only requires the ones that deviate from the defaults, and it is more readable.
Start from the default options `DiffOptions.default` and customize as follows:
```scala
// Scala
import uk.co.gresearch.spark.diff.{DiffOptions, DiffMode, DiffComparators}
val options = DiffOptions.default
.withDiffColumn("d")
.withLeftColumnPrefix("l")
.withRightColumnPrefix("r")
.withInsertDiffValue("i")
.withChangeDiffValue("c")
.withDeleteDiffValue("d")
.withNochangeDiffValue("n")
.withChangeColumn("changes")
.withDiffMode(DiffMode.Default)
.withSparseMode(true)
.withDefaultComparator(DiffComparators.epsilon(0.001))
.withComparator(DiffComparators.epsilon(0.001), DoubleType)
.withComparator(DiffComparators.epsilon(0.001), "float_column")
```
```python
# Python
from pyspark.sql.types import DoubleType
from gresearch.spark.diff import DiffOptions, DiffMode, DiffComparators
options = DiffOptions() \
.with_diff_column("d") \
.with_left_column_prefix("l") \
.with_right_column_prefix("r") \
.with_insert_diff_value("i") \
.with_change_diff_value("c") \
.with_delete_diff_value("d") \
.with_nochange_diff_value("n") \
.with_change_column("changes") \
.with_diff_mode(DiffMode.Default) \
.with_sparse_mode(True) \
.with_default_comparator(DiffComparators.epsilon(0.01)) \
.with_data_type_comparator(DiffComparators.epsilon(0.001), DoubleType()) \
.with_column_name_comparator(DiffComparators.epsilon(0.001), "float_column")
```
### Diffing Modes
The result of the diff transformation can have the following formats:
- *column by column*: The non-id columns are arranged column by column, i.e. for each non-id column
there are two columns next to each other in the diff result, one from the left
and one from the right dataset. This is useful to easily compare the values
for each column.
- *side by side*: The non-id columns from the left and right dataset are arranged side by side,
i.e. first there are all columns from the left dataset, then from the right one.
This is useful to visually compare the datasets as a whole, especially in conjunction
with the sparse mode.
- *left side*: Only the columns of the left dataset are present in the diff output. This mode
provides the left dataset as is, annotated with diff action and optional changed column names.
- *right side*: Only the columns of the right dataset are present in the diff output. This mode
provides the right dataset as given, as well as the diff action that has been applied to it.
This serves as a patch that, applied to the left dataset, results in the right dataset.
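A mode is selected via `DiffOptions` (a minimal sketch, reusing the `left` and `right` frames from the beginning of this document):
```scala
import uk.co.gresearch.spark.diff.{DiffMode, DiffOptions}

// produce the side-by-side output format
val options = DiffOptions.default.withDiffMode(DiffMode.SideBySide)
left.diff(right, options, "id").show()
```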
With the following two datasets `left` and `right`:
```scala
case class Value(id: Int, value: Option[String], label: Option[String])
val left = Seq(
Value(1, Some("one"), None),
Value(2, Some("two"), Some("number two")),
Value(3, Some("three"), Some("number three")),
Value(4, Some("four"), Some("number four")),
Value(5, Some("five"), Some("number five")),
).toDS
val right = Seq(
Value(1, Some("one"), Some("one")),
Value(2, Some("Two"), Some("number two")),
Value(3, Some("Three"), Some("number Three")),
Value(4, Some("four"), Some("number four")),
Value(6, Some("six"), Some("number six")),
).toDS
```
the diff modes produce the following outputs:
#### Column by Column
|diff |id |left_value|right_value|left_label |right_label |
|:---:|:---:|:--------:|:---------:|:----------:|:----------:|
|C |1 |one |one |*null* |one |
|C |2 |two |Two |number two |number two |
|C |3 |three |Three |number three|number Three|
|N |4 |four |four |number four |number four |
|D |5 |five |null |number five |*null* |
|I |6 |*null* |six |*null* |number six |
#### Side by Side
|diff |id |left_value|left_label |right_value|right_label |
|:---:|:---:|:--------:|:----------:|:---------:|:----------:|
|C |1 |one |*null* |one |one |
|C |2 |two |number two |Two |number two |
|C |3 |three |number three|Three |number Three|
|N |4 |four |number four |four |number four |
|D |5 |five |number five |null |*null* |
|I |6 |*null* |*null* |six |number six |
#### Left Side
|diff |id |value|label |
|:---:|:---:|:---:|:----------:|
|C |1 |one |null |
|C |2 |two |number two |
|C |3 |three|number three|
|N |4 |four |number four |
|D |5 |five |number five |
|I |6 |null |null |
#### Right Side
|diff |id |value|label |
|:---:|:---:|:---:|:----------:|
|C |1 |one |one |
|C |2 |Two |number two |
|C |3 |Three|number Three|
|N |4 |four |number four |
|D |5 |null |null |
|I |6 |six |number six |
### Sparse Mode
The diff modes above can be combined with sparse mode. In sparse mode, only values that differ between
the two datasets are present in the diff result; all other values are `null`.
The [Column by Column](#column-by-column) example above would look as follows in sparse mode:
|diff |id |left_value|right_value|left_label |right_label |
|:---:|:---:|:--------:|:---------:|:----------:|:----------:|
|C |1 |null |null |null |one |
|C |2 |two |Two |null |null |
|C |3 |three |Three |number three|number Three|
|N |4 |null |null |null |null |
|D |5 |five |null |number five |null |
|I |6 |null |six |null |number six |
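Sparse mode is enabled via `DiffOptions` as well (a minimal sketch):
```scala
// sparse output: values that are equal on both sides become null
val options = DiffOptions.default.withSparseMode(true)
left.diff(right, options, "id").show()
```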
### Comparators (Equality)
Values are compared for equality with the default `<=>` operator, which considers values
equal when both sides are `null`, or both sides are not `null` and equal.
The following alternative comparators are provided:
|Comparator|Description|
|:---------|:----------|
|`DiffComparators.epsilon(epsilon)`|Two values are equal when they are at most `epsilon` apart.
The comparator can be configured to use `epsilon` as an absolute (`.asAbsolute()`) threshold, or as relative (`.asRelative()`) to the larger value. Further, the threshold itself can be considered equal (`.asInclusive()`) or not equal (`.asExclusive()`):
`DiffComparators.epsilon(epsilon).asAbsolute().asInclusive()`: `x` and `y` are equal iff `abs(x - y) ≤ epsilon`
`DiffComparators.epsilon(epsilon).asAbsolute().asExclusive()`: `x` and `y` are equal iff `abs(x - y) < epsilon`
`DiffComparators.epsilon(epsilon).asRelative().asInclusive()`: `x` and `y` are equal iff `abs(x - y) ≤ epsilon * max(abs(x), abs(y))`
`DiffComparators.epsilon(epsilon).asRelative().asExclusive()`: `x` and `y` are equal iff `abs(x - y) < epsilon * max(abs(x), abs(y))`
|
|`DiffComparators.string()`|Two `StringType` values are compared while ignoring white space differences. For this comparison, sequences of whitespaces are collapsed into single whitespaces, and leading and trailing whitespaces are removed. With `DiffComparators.string(false)`, string values are compared with the default comparator.|
|`DiffComparators.duration(duration)`|Two `DateType` or `TimestampType` values are equal when they are at most `duration` apart. That duration is an instance of `java.time.Duration`.
The comparator can be configured to consider `duration` as equal (`.asInclusive()`) or not equal (`.asExclusive()`):
`DiffComparators.duration(duration).asInclusive()`: `x` and `y` are equal iff `x - y ≤ duration`
`DiffComparators.duration(duration).asExclusive()`: `x` and `y` are equal iff `x - y < duration`
|
|`DiffComparators.map[K,V](keyOrderSensitive)` (Scala only) `DiffComparators.map(keyType, valueType, keyOrderSensitive)`|Two `Map[K,V]` values are equal when they match in all their keys and values. With `keyOrderSensitive=true`, the order of the keys matters, with `keyOrderSensitive=false` (default), the order of keys is ignored.|
An example:
```scala
val left = Seq((1, 1.0), (2, 2.0), (3, 3.0)).toDF("id", "value")
val right = Seq((1, 1.0), (2, 2.02), (3, 3.05)).toDF("id", "value")
left.diff(right, "id").show()
```
|diff| id|left_value|right_value|
|----|---|----------|-----------|
| N| 1| 1.0| 1.0|
| C| 2| 2.0| 2.02|
| C| 3| 3.0| 3.05|
The second and third rows are considered `"C"`hanged because `2.0 != 2.02` and `3.0 != 3.05`, respectively.
With an inclusive relative epsilon of 1%, `2.0` and `2.02` are considered equal, while `3.0` and `3.05` are still not:
```scala
val options = DiffOptions.default
.withComparator(DiffComparators.epsilon(0.01).asRelative().asInclusive(), DoubleType)
left.diff(right, options, "id").show()
```
|diff| id|left_value|right_value|
|----|---|----------|-----------|
| N| 1| 1.0| 1.0|
| N| 2| 2.0| 2.02|
| C| 3| 3.0| 3.05|
The user can provide custom comparator implementations by implementing `scala.math.Equiv[T]`
or `uk.co.gresearch.spark.diff.DiffComparator`:
```scala
val intEquiv: Equiv[Int] = (x: Int, y: Int) => x == null && y == null || x != null && y != null && x.equals(y)
val anyEquiv: Equiv[Any] = (x: Any, y: Any) => x == null && y == null || x != null && y != null && x.equals(y)
val comparator: DiffComparator = (left: Column, right: Column) => left <=> right
import spark.implicits._
val options = DiffOptions.default
.withComparator(intEquiv)
.withComparator(anyEquiv, LongType, DoubleType)
.withComparator(anyEquiv, "column1", "column2")
.withComparator(comparator, StringType, FloatType)
.withComparator(comparator, "column3", "column4")
```
## Methods (Scala)
All Scala methods come in two variants, one without (as shown below) and one with an `options: DiffOptions` argument.
* `def diff(other: Dataset[T], idColumns: String*): DataFrame`
* `def diff[U](other: Dataset[U], idColumns: Seq[String], ignoreColumns: Seq[String]): DataFrame`
* `def diffAs[V](other: Dataset[T], idColumns: String*)(implicit diffEncoder: Encoder[V]): Dataset[V]`
* `def diffAs[U, V](other: Dataset[U], idColumns: Seq[String], ignoreColumns: Seq[String])(implicit diffEncoder: Encoder[V]): Dataset[V]`
* `def diffAs[V](other: Dataset[T], diffEncoder: Encoder[V], idColumns: String*): Dataset[V]`
* `def diffAs[U, V](other: Dataset[U], diffEncoder: Encoder[V], idColumns: Seq[String], ignoreColumns: Seq[String]): Dataset[V]`
* `def diffWith(other: Dataset[T], idColumns: String*): Dataset[(String, T, T)]`
* `def diffWith[U](other: Dataset[U], idColumns: Seq[String], ignoreColumns: Seq[String]): Dataset[(String, T, U)]`
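For example, a minimal sketch of a typed diff with `diffAs` (assuming the default diff column names and the `left` / `right` frames from the beginning of this document):
```scala
// typed result row matching the default column-by-column diff schema
case class ValueDiff(diff: String, id: Option[Int], left_value: Option[String], right_value: Option[String])

import spark.implicits._
val typed: Dataset[ValueDiff] = left.diffAs[ValueDiff](right, "id")
typed.show()
```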
## Methods (Java)
* `Dataset<Row> Diff.of[T](Dataset<T> left, Dataset<T> right, String... idColumns)`
* `Dataset<Row> Diff.of[T, U](Dataset<T> left, Dataset<U> right, List<String> idColumns, List<String> ignoreColumns)`
* `Dataset<V> Diff.ofAs[T, V](Dataset<T> left, Dataset<T> right, Encoder<V> diffEncoder, String... idColumns)`
* `Dataset<V> Diff.ofAs[T, U, V](Dataset<T> left, Dataset<U> right, Encoder<V> diffEncoder, List<String> idColumns, List<String> ignoreColumns)`
* `Dataset<Tuple3<String, T, T>> Diff.ofWith[T](Dataset<T> left, Dataset<T> right, String... idColumns)`
* `Dataset<Tuple3<String, T, U>> Diff.ofWith[T, U](Dataset<T> left, Dataset<U> right, List<String> idColumns, List<String> ignoreColumns)`
Given a `DiffOptions`, a customized `Differ` can be instantiated as `Differ differ = new Differ(options)`:
* `Dataset<Row> Differ.diff[T](Dataset<T> left, Dataset<T> right, String... idColumns)`
* `Dataset<Row> Differ.diff[T, U](Dataset<T> left, Dataset<U> right, List<String> idColumns, List<String> ignoreColumns)`
* `Dataset<V> Differ.diffAs[T, V](Dataset<T> left, Dataset<T> right, Encoder<V> diffEncoder, String... idColumns)`
* `Dataset<V> Differ.diffAs[T, U, V](Dataset<T> left, Dataset<U> right, Encoder<V> diffEncoder, List<String> idColumns, List<String> ignoreColumns)`
* `Dataset<Tuple3<String, T, T>> Differ.diffWith[T](Dataset<T> left, Dataset<T> right, String... idColumns)`
* `Dataset<Tuple3<String, T, U>> Differ.diffWith[T, U](Dataset<T> left, Dataset<U> right, List<String> idColumns, List<String> ignoreColumns)`
## Methods (Python)
* `def diff(self: DataFrame, other: DataFrame, *id_columns: str) -> DataFrame`
* `def diff(self: DataFrame, other: DataFrame, id_columns: List[str], ignore_columns: List[str]) -> DataFrame`
* `def diff(self: DataFrame, other: DataFrame, options: DiffOptions, *id_columns: str) -> DataFrame`
* `def diff(self: DataFrame, other: DataFrame, options: DiffOptions, id_columns: List[str], ignore_columns: List[str]) -> DataFrame`
* `def diffwith(self: DataFrame, other: DataFrame, *id_columns: str) -> DataFrame`
* `def diffwith(self: DataFrame, other: DataFrame, id_columns: List[str], ignore_columns: List[str]) -> DataFrame`
* `def diffwith(self: DataFrame, other: DataFrame, options: DiffOptions, *id_columns: str) -> DataFrame`
* `def diffwith(self: DataFrame, other: DataFrame, options: DiffOptions, id_columns: List[str], ignore_columns: List[str]) -> DataFrame`
## Diff Spark application
There is also a Spark application that can be used to create a diff DataFrame. The application reads two DataFrames
`left` and `right` from files or tables, executes the diff transformation and writes the result DataFrame to a file or table.
The Diff app can be run via `spark-submit`:
```shell
# Scala 2.12
spark-submit --packages com.github.scopt:scopt_2.12:4.1.0 spark-extension_2.12-2.7.0-3.4.jar --help
# Scala 2.13
spark-submit --packages com.github.scopt:scopt_2.13:4.1.0 spark-extension_2.13-2.7.0-3.4.jar --help
```
```
Spark Diff app (2.10.0-3.4)
Usage: spark-extension_2.13-2.10.0-3.4.jar [options] left right diff
left file path (requires format option) or table name to read left dataframe
right file path (requires format option) or table name to read right dataframe
diff file path (requires format option) or table name to write diff dataframe
Examples:
- Diff CSV files 'left.csv' and 'right.csv' and write result into CSV file 'diff.csv':
spark-submit --packages com.github.scopt:scopt_2.13:4.1.0 spark-extension_2.13-2.10.0-3.4.jar --format csv left.csv right.csv diff.csv
- Diff CSV file 'left.csv' with Parquet file 'right.parquet' with id column 'id', and write result into Hive table 'diff':
spark-submit --packages com.github.scopt:scopt_2.13:4.1.0 spark-extension_2.13-2.10.0-3.4.jar --left-format csv --right-format parquet --hive --id id left.csv right.parquet diff
Spark session
--master Spark master (local, yarn, ...), not needed with spark-submit
--app-name Spark application name
--hive enable Hive support to read from and write to Hive tables
Input and output
-f, --format input and output file format (csv, json, parquet, ...)
--left-format left input file format (csv, json, parquet, ...)
--right-format right input file format (csv, json, parquet, ...)
--output-format output file format (csv, json, parquet, ...)
-s, --schema input schema
--left-schema left input schema
--right-schema right input schema
--left-option:key=val left input option
--right-option:key=val right input option
--output-option:key=val output option
--id id column name
--ignore ignore column name
--save-mode save mode for writing output (Append, Overwrite, ErrorIfExists, Ignore, default ErrorIfExists)
--filter Filters for rows with these diff actions, with default diffing options use 'N', 'I', 'D', or 'C' (see 'Diffing options' section)
--statistics Only output statistics on how many rows exist per diff action (see 'Diffing options' section)
Diffing options
--diff-column column name for diff column (default 'diff')
--left-prefix prefix for left column names (default 'left')
--right-prefix prefix for right column names (default 'right')
--insert-value value for insertion (default 'I')
--change-value value for change (default 'C')
--delete-value value for deletion (default 'D')
--no-change-value value for no change (default 'N')
--change-column column name for change column (default is no such column)
--diff-mode diff mode (ColumnByColumn, SideBySide, LeftSide, RightSide, default ColumnByColumn)
--sparse enable sparse diff
General
--help prints this usage text
```
### Examples
Diff CSV files `left.csv` and `right.csv` and write result into CSV file `diff.csv`:
```shell
spark-submit --packages com.github.scopt:scopt_2.13:4.1.0 spark-extension_2.13-2.7.0-3.4.jar --format csv left.csv right.csv diff.csv
```
Diff CSV file `left.csv` with Parquet file `right.parquet` with id column `id`, and write result into Hive table `diff`:
```shell
spark-submit --packages com.github.scopt:scopt_2.13:4.1.0 spark-extension_2.13-2.7.0-3.4.jar --left-format csv --right-format parquet --hive --id id left.csv right.parquet diff
```
================================================
FILE: GROUPS.md
================================================
# Sorted Groups
Spark provides the ability to group rows by an arbitrary key,
and then provides an iterator over each of these groups.
This allows iterating over groups that are too large to fit into memory:
```scala
import org.apache.spark.sql.Dataset
import spark.implicits._
case class Val(id: Int, seq: Int, value: Double)
val ds: Dataset[Val] = Seq(
Val(1, 1, 1.1),
Val(1, 2, 1.2),
Val(1, 3, 1.3),
Val(2, 1, 2.1),
Val(2, 2, 2.2),
Val(2, 3, 2.3),
Val(3, 1, 3.1)
).reverse.toDS().repartition(3).cache()
// order of iterator IS NOT guaranteed
ds.groupByKey(v => v.id)
.flatMapGroups((key, it) => it.zipWithIndex.map(v => (key, v._2, v._1.seq, v._1.value)))
.toDF("key", "index", "seq", "value")
.show(false)
+---+-----+---+-----+
|key|index|seq|value|
+---+-----+---+-----+
|1 |0 |3 |1.3 |
|1 |1 |2 |1.2 |
|1 |2 |1 |1.1 |
|2 |0 |1 |2.1 |
|2 |1 |3 |2.3 |
|2 |2 |2 |2.2 |
|3 |0 |1 |3.1 |
+---+-----+---+-----+
```
However, we have no control over the order of the group iterators.
If we want the iterators to be ordered according to `seq`, we can do the following:
```scala
import uk.co.gresearch.spark._
// the group key $"id" needs an ordering
implicit val ordering: Ordering.Int.type = Ordering.Int
// order of iterator IS guaranteed
ds.groupBySorted($"id")($"seq")
.flatMapSortedGroups((key, it) => it.zipWithIndex.map(v => (key, v._2, v._1.seq, v._1.value)))
.toDF("key", "index", "seq", "value")
.show(false)
+---+-----+---+-----+
|key|index|seq|value|
+---+-----+---+-----+
|1 |0 |1 |1.1 |
|1 |1 |2 |1.2 |
|1 |2 |3 |1.3 |
|2 |0 |1 |2.1 |
|2 |1 |2 |2.2 |
|2 |2 |3 |2.3 |
|3 |0 |1 |3.1 |
+---+-----+---+-----+
```
Now, iterators are ordered according to `seq`, as demonstrated by the value of `index`,
which has been generated by `it.zipWithIndex`.
Instead of column expressions, we can also use lambdas to define group key and group order:
```scala
ds.groupByKeySorted(v => v.id)(v => v.seq)
.flatMapSortedGroups((key, it) => it.zipWithIndex.map(v => (key, v._2, v._1.seq, v._1.value)))
.toDF("key", "index", "seq", "value")
.show(false)
```
**Note:** Using lambdas here hides from Spark which columns we use for grouping and sorting.
Query optimization cannot improve partitioning and sorting in this case. Use column expressions when possible.
================================================
FILE: HISTOGRAM.md
================================================
# Histogram
For a table `df` like
|user |score|
|:-----:|:---:|
|Alice |101 |
|Alice |221 |
|Alice |211 |
|Alice |176 |
|Bob |276 |
|Bob |232 |
|Bob |258 |
|Charlie|221 |
you can compute the histogram for each user
|user |≤100 |≤200 |>200 |
|:-----:|:---:|:---:|:---:|
|Alice |0 |2 |2 |
|Bob |0 |0 |3 |
|Charlie|0 |0 |1 |
as follows:
df.withColumn("≤100", when($"score" <= 100, 1).otherwise(0))
.withColumn("≤200", when($"score" > 100 && $"score" <= 200, 1).otherwise(0))
.withColumn(">200", when($"score" > 200, 1).otherwise(0))
.groupBy($"user")
.agg(
sum($"≤100").as("≤100"),
sum($"≤200").as("≤200"),
sum($">200").as(">200")
)
.orderBy($"user")
Equivalent to that query is:
```scala
import uk.co.gresearch.spark._

df.histogram(Seq(100, 200), $"score", $"user").orderBy($"user")
```
The first argument is a sequence of thresholds, the second argument provides the value column.
The subsequent arguments refer to the aggregation columns (`groupBy`). Only aggregation columns
will be in the result DataFrame.
In Java, call:
```java
import uk.co.gresearch.spark.Histogram;

Histogram.of(df, Arrays.asList(100, 200), new Column("score"), new Column("user")).orderBy(new Column("user"));
```
In Python, call:
```python
import gresearch.spark

df.histogram([100, 200], 'score', 'user').orderBy('user')
```
Note that this feature is not supported in Python when connected with a [Spark Connect server](README.md#spark-connect-server).
================================================
FILE: LICENSE
================================================
Apache License
Version 2.0, January 2004
http://www.apache.org/licenses/
TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
1. Definitions.
"License" shall mean the terms and conditions for use, reproduction,
and distribution as defined by Sections 1 through 9 of this document.
"Licensor" shall mean the copyright owner or entity authorized by
the copyright owner that is granting the License.
"Legal Entity" shall mean the union of the acting entity and all
other entities that control, are controlled by, or are under common
control with that entity. For the purposes of this definition,
"control" means (i) the power, direct or indirect, to cause the
direction or management of such entity, whether by contract or
otherwise, or (ii) ownership of fifty percent (50%) or more of the
outstanding shares, or (iii) beneficial ownership of such entity.
"You" (or "Your") shall mean an individual or Legal Entity
exercising permissions granted by this License.
"Source" form shall mean the preferred form for making modifications,
including but not limited to software source code, documentation
source, and configuration files.
"Object" form shall mean any form resulting from mechanical
transformation or translation of a Source form, including but
not limited to compiled object code, generated documentation,
and conversions to other media types.
"Work" shall mean the work of authorship, whether in Source or
Object form, made available under the License, as indicated by a
copyright notice that is included in or attached to the work
(an example is provided in the Appendix below).
"Derivative Works" shall mean any work, whether in Source or Object
form, that is based on (or derived from) the Work and for which the
editorial revisions, annotations, elaborations, or other modifications
represent, as a whole, an original work of authorship. For the purposes
of this License, Derivative Works shall not include works that remain
separable from, or merely link (or bind by name) to the interfaces of,
the Work and Derivative Works thereof.
"Contribution" shall mean any work of authorship, including
the original version of the Work and any modifications or additions
to that Work or Derivative Works thereof, that is intentionally
submitted to Licensor for inclusion in the Work by the copyright owner
or by an individual or Legal Entity authorized to submit on behalf of
the copyright owner. For the purposes of this definition, "submitted"
means any form of electronic, verbal, or written communication sent
to the Licensor or its representatives, including but not limited to
communication on electronic mailing lists, source code control systems,
and issue tracking systems that are managed by, or on behalf of, the
Licensor for the purpose of discussing and improving the Work, but
excluding communication that is conspicuously marked or otherwise
designated in writing by the copyright owner as "Not a Contribution."
"Contributor" shall mean Licensor and any individual or Legal Entity
on behalf of whom a Contribution has been received by Licensor and
subsequently incorporated within the Work.
2. Grant of Copyright License. Subject to the terms and conditions of
this License, each Contributor hereby grants to You a perpetual,
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
copyright license to reproduce, prepare Derivative Works of,
publicly display, publicly perform, sublicense, and distribute the
Work and such Derivative Works in Source or Object form.
3. Grant of Patent License. Subject to the terms and conditions of
this License, each Contributor hereby grants to You a perpetual,
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
(except as stated in this section) patent license to make, have made,
use, offer to sell, sell, import, and otherwise transfer the Work,
where such license applies only to those patent claims licensable
by such Contributor that are necessarily infringed by their
Contribution(s) alone or by combination of their Contribution(s)
with the Work to which such Contribution(s) was submitted. If You
institute patent litigation against any entity (including a
cross-claim or counterclaim in a lawsuit) alleging that the Work
or a Contribution incorporated within the Work constitutes direct
or contributory patent infringement, then any patent licenses
granted to You under this License for that Work shall terminate
as of the date such litigation is filed.
4. Redistribution. You may reproduce and distribute copies of the
Work or Derivative Works thereof in any medium, with or without
modifications, and in Source or Object form, provided that You
meet the following conditions:
(a) You must give any other recipients of the Work or
Derivative Works a copy of this License; and
(b) You must cause any modified files to carry prominent notices
stating that You changed the files; and
(c) You must retain, in the Source form of any Derivative Works
that You distribute, all copyright, patent, trademark, and
attribution notices from the Source form of the Work,
excluding those notices that do not pertain to any part of
the Derivative Works; and
(d) If the Work includes a "NOTICE" text file as part of its
distribution, then any Derivative Works that You distribute must
include a readable copy of the attribution notices contained
within such NOTICE file, excluding those notices that do not
pertain to any part of the Derivative Works, in at least one
of the following places: within a NOTICE text file distributed
as part of the Derivative Works; within the Source form or
documentation, if provided along with the Derivative Works; or,
within a display generated by the Derivative Works, if and
wherever such third-party notices normally appear. The contents
of the NOTICE file are for informational purposes only and
do not modify the License. You may add Your own attribution
notices within Derivative Works that You distribute, alongside
or as an addendum to the NOTICE text from the Work, provided
that such additional attribution notices cannot be construed
as modifying the License.
You may add Your own copyright statement to Your modifications and
may provide additional or different license terms and conditions
for use, reproduction, or distribution of Your modifications, or
for any such Derivative Works as a whole, provided Your use,
reproduction, and distribution of the Work otherwise complies with
the conditions stated in this License.
5. Submission of Contributions. Unless You explicitly state otherwise,
any Contribution intentionally submitted for inclusion in the Work
by You to the Licensor shall be under the terms and conditions of
this License, without any additional terms or conditions.
Notwithstanding the above, nothing herein shall supersede or modify
the terms of any separate license agreement you may have executed
with Licensor regarding such Contributions.
6. Trademarks. This License does not grant permission to use the trade
names, trademarks, service marks, or product names of the Licensor,
except as required for reasonable and customary use in describing the
origin of the Work and reproducing the content of the NOTICE file.
7. Disclaimer of Warranty. Unless required by applicable law or
agreed to in writing, Licensor provides the Work (and each
Contributor provides its Contributions) on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
implied, including, without limitation, any warranties or conditions
of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
PARTICULAR PURPOSE. You are solely responsible for determining the
appropriateness of using or redistributing the Work and assume any
risks associated with Your exercise of permissions under this License.
8. Limitation of Liability. In no event and under no legal theory,
whether in tort (including negligence), contract, or otherwise,
unless required by applicable law (such as deliberate and grossly
negligent acts) or agreed to in writing, shall any Contributor be
liable to You for damages, including any direct, indirect, special,
incidental, or consequential damages of any character arising as a
result of this License or out of the use or inability to use the
Work (including but not limited to damages for loss of goodwill,
work stoppage, computer failure or malfunction, or any and all
other commercial damages or losses), even if such Contributor
has been advised of the possibility of such damages.
9. Accepting Warranty or Additional Liability. While redistributing
the Work or Derivative Works thereof, You may choose to offer,
and charge a fee for, acceptance of support, warranty, indemnity,
or other liability obligations and/or rights consistent with this
License. However, in accepting such obligations, You may act only
on Your own behalf and on Your sole responsibility, not on behalf
of any other Contributor, and only if You agree to indemnify,
defend, and hold each Contributor harmless for any liability
incurred by, or claims asserted against, such Contributor by reason
of your accepting any such warranty or additional liability.
END OF TERMS AND CONDITIONS
APPENDIX: How to apply the Apache License to your work.
To apply the Apache License to your work, attach the following
boilerplate notice, with the fields enclosed by brackets "[]"
replaced with your own identifying information. (Don't include
the brackets!) The text should be enclosed in the appropriate
comment syntax for the file format. We also recommend that a
file or class name and description of purpose be included on the
same "printed page" as the copyright notice for easier
identification within third-party archives.
Copyright [yyyy] [name of copyright owner]
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
================================================
FILE: MAINTAINERS.md
================================================
## Current maintainers of the project
| Maintainer | GitHub ID |
| ---------------------- | ------------------------------------------------------- |
| Enrico Minack | [EnricoMi](https://github.com/EnricoMi) |
================================================
FILE: PARQUET.md
================================================
# Parquet Metadata
The structure of Parquet files (the metadata, not the data stored in Parquet) can be inspected, similarly to [parquet-tools](https://pypi.org/project/parquet-tools/)
or [parquet-cli](https://pypi.org/project/parquet-cli/),
by reading from a simple Spark data source.
Parquet metadata can be read on [file level](#parquet-file-metadata),
[schema level](#parquet-file-schema),
[row group level](#parquet-block--rowgroup-metadata),
[column chunk level](#parquet-block-column-metadata) and
[Spark Parquet partition level](#parquet-partition-metadata).
Multiple files can be inspected at once.
Any location that can be read by Spark (`spark.read.parquet(…)`) can be inspected.
This means the path can point to a single Parquet file, a directory with Parquet files,
or multiple paths separated by a comma (`,`). Paths can contain wildcards like `*`.
Multiple files will be inspected in parallel and distributed by Spark.
No actual rows or values are read from the Parquet files, only metadata, which is very fast.
This allows inspecting Parquet files that have different schemata with one `spark.read` operation.
First, import the new Parquet metadata data sources:
```scala
// Scala
import uk.co.gresearch.spark.parquet._
```
```python
# Python
import gresearch.spark.parquet
```
Then, the following metadata become available:
## Parquet file metadata
Read the metadata of Parquet files into a Dataframe:
```scala
// Scala
spark.read.parquetMetadata("/path/to/parquet").show()
```
```python
# Python
spark.read.parquet_metadata("/path/to/parquet").show()
```
```
+-------------+------+---------------+-----------------+----+-------+------+-----+--------------------+--------------------+-----------+--------------------+
| filename|blocks|compressedBytes|uncompressedBytes|rows|columns|values|nulls| createdBy| schema| encryption| keyValues|
+-------------+------+---------------+-----------------+----+-------+------+-----+--------------------+--------------------+-----------+--------------------+
|file1.parquet| 1| 1268| 1652| 100| 2| 200| 0|parquet-mr versio...|message spark_sch...|UNENCRYPTED|{org.apache.spark...|
|file2.parquet| 2| 2539| 3302| 200| 2| 400| 0|parquet-mr versio...|message spark_sch...|UNENCRYPTED|{org.apache.spark...|
+-------------+------+---------------+-----------------+----+-------+------+-----+--------------------+--------------------+-----------+--------------------+
```
The Dataframe provides the following per-file information:
|column |type | description |
|:-----------------|:----:|:-------------------------------------------------------------------------------|
|filename |string| The Parquet file name |
|blocks |int | Number of blocks / RowGroups in the Parquet file |
|compressedBytes |long | Number of compressed bytes of all blocks |
|uncompressedBytes |long | Number of uncompressed bytes of all blocks |
|rows |long | Number of rows in the file |
|columns |int | Number of columns in the file |
|values |long | Number of values in the file |
|nulls |long | Number of null values in the file |
|createdBy |string| The createdBy string of the Parquet file, e.g. library used to write the file |
|schema |string| The schema |
|encryption |string| The encryption (requires `org.apache.parquet:parquet-hadoop:1.12.4` and above) |
|keyValues |string-to-string map| Key-value data of the file |
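As noted above, multiple locations can be inspected with a single read. A minimal sketch, assuming the metadata source accepts multiple paths just like `spark.read.parquet(…)` (the paths are hypothetical):
```scala
// Scala: one operation over several hypothetical locations, wildcards included
spark.read.parquetMetadata("/data/2023/*.parquet", "/data/2024/").show()
```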
## Parquet file schema
Read the schema of Parquet files into a Dataframe:
```scala
// Scala
spark.read.parquetSchema("/path/to/parquet").show()
```
```python
# Python
spark.read.parquet_schema("/path/to/parquet").show()
```
```
+------------+----------+------------------+----------+------+------+----------------+--------------------+-----------+-------------+------------------+------------------+------------------+
| filename|columnName| columnPath|repetition| type|length| originalType| logicalType|isPrimitive|primitiveType| primitiveOrder|maxDefinitionLevel|maxRepetitionLevel|
+------------+----------+------------------+----------+------+------+----------------+--------------------+-----------+-------------+------------------+------------------+------------------+
|file.parquet| a| [a]| REQUIRED| INT64| 0| NULL| NULL| true| INT64|TYPE_DEFINED_ORDER| 0| 0|
|file.parquet| x| [b, x]| REQUIRED| INT32| 0| NULL| NULL| true| INT32|TYPE_DEFINED_ORDER| 1| 0|
|file.parquet| y| [b, y]| REQUIRED|DOUBLE| 0| NULL| NULL| true| DOUBLE|TYPE_DEFINED_ORDER| 1| 0|
|file.parquet| z| [b, z]| OPTIONAL| INT64| 0|TIMESTAMP_MICROS|TIMESTAMP(MICROS,...| true| INT64|TYPE_DEFINED_ORDER| 2| 0|
|file.parquet| element|[c, list, element]| OPTIONAL|BINARY| 0| UTF8| STRING| true| BINARY|TYPE_DEFINED_ORDER| 3| 1|
+------------+----------+------------------+----------+------+------+----------------+--------------------+-----------+-------------+------------------+------------------+------------------+
```
The Dataframe provides the following per-column information:
|column | type | description |
|:-----------------|:------------:|:----------------------------------------------------------------------------------|
|filename | string | The Parquet file name |
|columnName | string | The column name |
|columnPath | string array | The column path |
|repetition | string | The repetition |
|type | string | The data type |
|length | int | The length of the type |
|originalType | string | The original type (requires `org.apache.parquet:parquet-hadoop:1.11.0` and above) |
|logicalType | string | The logical type (requires `org.apache.parquet:parquet-hadoop:1.11.0` and above) |
|isPrimitive | boolean | True if type is primitive |
|primitiveType | string | The primitive type |
|primitiveOrder | string | The order of the primitive type |
|maxDefinitionLevel| int | The max definition level |
|maxRepetitionLevel| int | The max repetition level |
## Parquet block / RowGroup metadata
Read the metadata of Parquet blocks / RowGroups into a Dataframe:
```scala
// Scala
spark.read.parquetBlocks("/path/to/parquet").show()
```
```python
# Python
spark.read.parquet_blocks("/path/to/parquet").show()
```
```
+-------------+-----+----------+---------------+-----------------+----+-------+------+-----+
| filename|block|blockStart|compressedBytes|uncompressedBytes|rows|columns|values|nulls|
+-------------+-----+----------+---------------+-----------------+----+-------+------+-----+
|file1.parquet| 1| 4| 1269| 1651| 100| 2| 200| 0|
|file2.parquet| 1| 4| 1268| 1652| 100| 2| 200| 0|
|file2.parquet| 2| 1273| 1270| 1651| 100| 2| 200| 0|
+-------------+-----+----------+---------------+-----------------+----+-------+------+-----+
```
|column |type |description |
|:-----------------|:----:|:----------------------------------------------|
|filename |string|The Parquet file name |
|block |int |Block / RowGroup number starting at 1 |
|blockStart |long |Start position of the block in the Parquet file|
|compressedBytes |long |Number of compressed bytes in block |
|uncompressedBytes |long |Number of uncompressed bytes in block |
|rows |long |Number of rows in block |
|columns |int |Number of columns in block |
|values |long |Number of values in block |
|nulls |long |Number of null values in block |
## Parquet block column metadata
Read the metadata of Parquet block columns into a Dataframe:
```scala
// Scala
spark.read.parquetBlockColumns("/path/to/parquet").show()
```
```python
# Python
spark.read.parquet_block_columns("/path/to/parquet").show()
```
```
+-------------+-----+------+------+-------------------+-------------------+--------------------+------------------+-----------+---------------+-----------------+------+-----+
| filename|block|column| codec| type| encodings| minValue| maxValue|columnStart|compressedBytes|uncompressedBytes|values|nulls|
+-------------+-----+------+------+-------------------+-------------------+--------------------+------------------+-----------+---------------+-----------------+------+-----+
|file1.parquet| 1| [id]|SNAPPY| required int64 id|[BIT_PACKED, PLAIN]| 0| 99| 4| 437| 826| 100| 0|
|file1.parquet| 1| [val]|SNAPPY|required double val|[BIT_PACKED, PLAIN]|0.005067503372006343|0.9973357672164814| 441| 831| 826| 100| 0|
|file2.parquet| 1| [id]|SNAPPY| required int64 id|[BIT_PACKED, PLAIN]| 100| 199| 4| 438| 825| 100| 0|
|file2.parquet| 1| [val]|SNAPPY|required double val|[BIT_PACKED, PLAIN]|0.010617521596503865| 0.999189783846449| 442| 831| 826| 100| 0|
|file2.parquet| 2| [id]|SNAPPY| required int64 id|[BIT_PACKED, PLAIN]| 200| 299| 1273| 440| 826| 100| 0|
|file2.parquet| 2| [val]|SNAPPY|required double val|[BIT_PACKED, PLAIN]|0.011277044401634018| 0.970525681750662| 1713| 830| 825| 100| 0|
+-------------+-----+------+------+-------------------+-------------------+--------------------+------------------+-----------+---------------+-----------------+------+-----+
```
| column | type | description |
|:------------------|:-------------:|:--------------------------------------------------------------------------------------------------|
| filename | string | The Parquet file name |
| block | int | Block / RowGroup number starting at 1 |
| column | array | Block / RowGroup column name |
| codec | string | The codec used to compress the block column values |
| type | string | The data type of the block column |
| encodings | array | Encodings of the block column |
| isEncrypted | boolean | Whether block column is encrypted (requires `org.apache.parquet:parquet-hadoop:1.12.3` and above) |
| minValue | string | Minimum value of this column in this block |
| maxValue | string | Maximum value of this column in this block |
| columnStart | long | Start position of the block column in the Parquet file |
| compressedBytes | long | Number of compressed bytes of this block column |
| uncompressedBytes | long | Number of uncompressed bytes of this block column |
| values | long | Number of values in this block column |
| nulls | long | Number of null values in this block column |
## Parquet partition metadata
Read the metadata of how Spark partitions Parquet files into a Dataframe:
```scala
// Scala
spark.read.parquetPartitions("/path/to/parquet").show()
```
```python
# Python
spark.read.parquet_partitions("/path/to/parquet").show()
```
```
+---------+-----+----+------+------+---------------+-----------------+----+-------+------+-----+-------------+----------+
|partition|start| end|length|blocks|compressedBytes|uncompressedBytes|rows|columns|values|nulls| filename|fileLength|
+---------+-----+----+------+------+---------------+-----------------+----+-------+------+-----+-------------+----------+
| 1| 0|1024| 1024| 1| 1268| 1652| 100| 2| 200| 0|file1.parquet| 1930|
| 2| 1024|1930| 906| 0| 0| 0| 0| 0| 0| 0|file1.parquet| 1930|
| 3| 0|1024| 1024| 1| 1269| 1651| 100| 2| 200| 0|file2.parquet| 3493|
| 4| 1024|2048| 1024| 1| 1270| 1651| 100| 2| 200| 0|file2.parquet| 3493|
| 5| 2048|3072| 1024| 0| 0| 0| 0| 0| 0| 0|file2.parquet| 3493|
| 6| 3072|3493| 421| 0| 0| 0| 0| 0| 0| 0|file2.parquet| 3493|
+---------+-----+----+------+------+---------------+-----------------+----+-------+------+-----+-------------+----------+
```
|column |type |description |
|:----------------|:----:|:---------------------------------------------------------|
|partition |int |The Spark partition id |
|start |long |The start position of the partition |
|end |long |The end position of the partition |
|length |long |The length of the partition |
|blocks |int |The number of Parquet blocks / RowGroups in this partition|
|compressedBytes |long |The number of compressed bytes in this partition |
|uncompressedBytes|long |The number of uncompressed bytes in this partition |
|rows |long |The number of rows in this partition |
|columns |int |The number of columns in this partition |
|values |long |The number of values in this partition |
|nulls |long |The number of null values in this partition |
|filename |string|The Parquet file name |
|fileLength |long |The length of the Parquet file |
## Performance
Retrieving Parquet metadata is parallelized and distributed by Spark. The result Dataframe
has as many partitions as there are Parquet files in the given `path`, but at most
`spark.sparkContext.defaultParallelism` partitions.
Each result partition reads Parquet metadata from its Parquet files sequentially,
while partitions are executed in parallel (depending on the number of Spark cores of your Spark job).
You can control the number of partitions via the `parallelism` parameter:
```scala
// Scala
spark.read.parquetMetadata(100, "/path/to/parquet")
spark.read.parquetSchema(100, "/path/to/parquet")
spark.read.parquetBlocks(100, "/path/to/parquet")
spark.read.parquetBlockColumns(100, "/path/to/parquet")
spark.read.parquetPartitions(100, "/path/to/parquet")
```
```python
# Python
spark.read.parquet_metadata("/path/to/parquet", parallelism=100)
spark.read.parquet_schema("/path/to/parquet", parallelism=100)
spark.read.parquet_blocks("/path/to/parquet", parallelism=100)
spark.read.parquet_block_columns("/path/to/parquet", parallelism=100)
spark.read.parquet_partitions("/path/to/parquet", parallelism=100)
```
## Encryption
Reading [encrypted Parquet is supported](https://spark.apache.org/docs/latest/sql-data-sources-parquet.html#columnar-encryption).
Files encrypted with a [plaintext footer](https://github.com/apache/parquet-format/blob/master/Encryption.md#55-plaintext-footer-mode)
can be read without any encryption keys; encrypted Parquet metadata are then shown as `NULL` values in the result Dataframe.
Encrypted Parquet files with an encrypted footer require only the footer encryption key. No column encryption keys are needed.
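A minimal sketch for reading metadata of files with an encrypted footer, following the Spark columnar-encryption documentation linked above (the crypto factory and in-memory KMS classes are taken from that documentation; the key material and path are placeholders):
```scala
// Scala: provide only the footer key via the Hadoop configuration (placeholder values)
val hadoopConf = spark.sparkContext.hadoopConfiguration
hadoopConf.set("parquet.crypto.factory.class",
  "org.apache.parquet.crypto.keytools.PropertiesDrivenCryptoFactory")
hadoopConf.set("parquet.encryption.kms.client.class",
  "org.apache.parquet.crypto.keytools.mocks.InMemoryKMS")
hadoopConf.set("parquet.encryption.key.list", "footerKey:AAECAwQFBgcICQoLDA0ODw==")

spark.read.parquetMetadata("/path/to/encrypted/parquet").show()
```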
## Known Issues
Note that this feature is not supported in Python when connected with a [Spark Connect server](README.md#spark-connect-server).
================================================
FILE: PARTITIONING.md
================================================
# Partitioned Writing
If you have ever used `Dataset[T].write.partitionBy`, here is how you can minimize the number of
written files and obtain same-size files.
Spark has two different concepts both referred to as partitioning. Central to Spark is the
concept of how a `Dataset[T]` is split into partitions where a Spark worker processes
a single partition at a time. This is the fundamental concept of how Spark scales with data.
When writing a `Dataset` `ds` to a file-based storage, that output file is actually a directory:
```scala
ds.write.csv("file.csv")
```
The directory structure looks like:
```
file.csv
file.csv/part-00000-7d34816f-bb53-4f44-ab9d-a62d570e5de0-c000.csv
file.csv/part-00001-7d34816f-bb53-4f44-ab9d-a62d570e5de0-c000.csv
file.csv/part-00002-7d34816f-bb53-4f44-ab9d-a62d570e5de0-c000.csv
file.csv/part-00003-7d34816f-bb53-4f44-ab9d-a62d570e5de0-c000.csv
file.csv/part-00004-7d34816f-bb53-4f44-ab9d-a62d570e5de0-c000.csv
file.csv/_SUCCESS
```
When writing, the output can be partitioned (`partitionBy`) by one or more columns of the `Dataset`.
For each distinct `value` in such a column `col`, an individual sub-directory named `col=value` is created in your output path.
Inside each sub-directory, multiple partition files exist, all containing only data where column `col` has value `value`.
To remove redundancy, those files do not contain that column anymore.
```
file.csv/property=descr/part-00001-8eb44de1-2c33-4f95-a1be-8d1b4e35eb4a.c000.csv
file.csv/property=descr/part-00002-8eb44de1-2c33-4f95-a1be-8d1b4e35eb4a.c000.csv
file.csv/property=descr/part-00003-8eb44de1-2c33-4f95-a1be-8d1b4e35eb4a.c000.csv
file.csv/property=descr/part-00004-8eb44de1-2c33-4f95-a1be-8d1b4e35eb4a.c000.csv
file.csv/property=label/part-00001-8eb44de1-2c33-4f95-a1be-8d1b4e35eb4a.c000.csv
file.csv/property=label/part-00002-8eb44de1-2c33-4f95-a1be-8d1b4e35eb4a.c000.csv
file.csv/property=label/part-00003-8eb44de1-2c33-4f95-a1be-8d1b4e35eb4a.c000.csv
file.csv/property=label/part-00004-8eb44de1-2c33-4f95-a1be-8d1b4e35eb4a.c000.csv
file.csv/_SUCCESS
```
Data that is not organized before writing ends up with the same number of files
in each of the sub-directories, even if some sub-directories contain only a fraction of
the rows of others. What you would like instead is fewer files in smaller
and more files in larger partition sub-directories. Further, all files should have
roughly the same number of rows.
For this, you first have to range-partition the `Dataset` according to your partition columns.
```scala
ds.repartitionByRange($"property", $"id")
  .write
  .partitionBy("property")
  .csv("file.csv")
```
This organizes the data optimally for writing it partitioned by column `property`.
```
file.csv/property=descr/part-00000-6317db5e-5161-41f1-8227-ffeaf06a3e41.c000.csv
file.csv/property=descr/part-00001-6317db5e-5161-41f1-8227-ffeaf06a3e41.c000.csv
file.csv/property=label/part-00002-6317db5e-5161-41f1-8227-ffeaf06a3e41.c000.csv
file.csv/property=label/part-00003-6317db5e-5161-41f1-8227-ffeaf06a3e41.c000.csv
file.csv/property=label/part-00004-6317db5e-5161-41f1-8227-ffeaf06a3e41.c000.csv
file.csv/_SUCCESS
```
This brings all rows with the same values in the `property` and `id` columns into the same file.
If you need each file to be further sorted by additional columns, e.g. `ts`, then you can do this with `sortWithinPartitions`.
```scala
ds.repartitionByRange($"property", $"id")
  .sortWithinPartitions($"property", $"id", $"ts")
  .cache // this is needed for Spark 3.0 to 3.3 with AQE enabled: SPARK-40588
  .write
  .partitionBy("property")
  .csv("file.csv")
```
Sometimes you want to write-partition by some expression that is not a column of your data,
e.g. the date representation of the `ts` column.
```scala
ds.withColumn("date", $"ts".cast(DateType))
  .repartitionByRange($"date", $"id")
  .sortWithinPartitions($"date", $"id", $"ts")
  .cache // this is needed for Spark 3.0 to 3.3 with AQE enabled: SPARK-40588
  .write
  .partitionBy("date")
  .csv("file.csv")
```
All of the above constructs can be replaced with a single meaningful operation:
```scala
ds.writePartitionedBy(Seq($"ts".cast(DateType).as("date")), Seq($"id"), Seq($"ts"))
  .csv("file.csv")
```
For Spark 3.0 to 3.3 with AQE enabled (see [SPARK-40588](https://issues.apache.org/jira/browse/SPARK-40588)),
`writePartitionedBy` has to cache an internally created DataFrame. This can be unpersisted after writing
is finished. Provide an `UnpersistHandle` for this purpose:
```scala
val unpersist = UnpersistHandle()

ds.writePartitionedBy(…, unpersistHandle = Some(unpersist))
  .csv("file.csv")

unpersist()
```
More details about this issue can be found [here](https://www.gresearch.co.uk/blog/article/guaranteeing-in-partition-order-for-partitioned-writing-in-apache-spark/).
================================================
FILE: PYSPARK-DEPS.md
================================================
# PySpark dependencies
Using PySpark on a cluster requires all cluster nodes to have those Python packages installed that are required by the PySpark job.
Such a deployment can be cumbersome, especially when running in an interactive notebook.
The `spark-extension` package allows installing Python packages programmatically by the PySpark application itself (PySpark ≥ 3.1.0).
These packages are only accessible by that PySpark application, and they are removed on calling `spark.stop()`.
Either install the `spark-extension` Maven package, or the `pyspark-extension` PyPi package (on the driver only),
as described [here](README.md#using-spark-extension).
## Installing packages with `pip`
Python packages can be installed with `pip` as follows:
```python
# noinspection PyUnresolvedReferences
from gresearch.spark import *
spark.install_pip_package("pandas", "pyarrow")
```
The above example installs the `pandas` and `pyarrow` packages via `pip`. Method `install_pip_package` takes any `pip` command line arguments:
```python
# install packages with version specs
spark.install_pip_package("pandas==1.4.3", "pyarrow~=8.0.0")
# install packages from package sources (e.g. git clone https://github.com/pandas-dev/pandas.git)
spark.install_pip_package("./pandas/")
# install packages from git repo
spark.install_pip_package("git+https://github.com/pandas-dev/pandas.git@main")
# use a pip cache directory to cache downloaded and built whl files
spark.install_pip_package("pandas", "pyarrow", "--cache-dir", "/home/user/.cache/pip")
# use an alternative index url (other than https://pypi.org/simple)
spark.install_pip_package("pandas", "pyarrow", "--index-url", "https://artifacts.company.com/pypi/simple")
# install pip packages quietly (only disables output of PIP)
spark.install_pip_package("pandas", "pyarrow", "--quiet")
```
## Installing Python projects with Poetry
Python projects can be installed from sources, including their dependencies, using [Poetry](https://python-poetry.org/):
```python
# noinspection PyUnresolvedReferences
from gresearch.spark import *
spark.install_poetry_project("../my-poetry-project/", poetry_python="../venv-poetry/bin/python")
```
## Example
This example uses `install_pip_package` in a Spark standalone cluster.
First checkout the example code:
```shell
git clone https://github.com/G-Research/spark-extension.git
cd spark-extension/examples/python-deps
```
Build a Docker image based on the official Spark release:
```shell
docker build -t spark-extension-example-docker .
```
Start the example Spark standalone cluster consisting of a Spark master and one worker:
```shell
docker compose -f docker-compose.yml up -d
```
Run the `example.py` Spark application on the example cluster:
```shell
docker exec spark-master spark-submit --master spark://master:7077 --packages uk.co.gresearch.spark:spark-extension_2.12:2.15.0-3.5 /example/example.py
```
The `--packages uk.co.gresearch.spark:spark-extension_2.12:2.15.0-3.5` argument
tells `spark-submit` to add the `spark-extension` Maven package to the Spark job.
Alternatively, install the `pyspark-extension` PyPi package via `pip install` and remove the `--packages` argument from `spark-submit`:
```shell
docker exec spark-master pip install --user pyspark_extension==2.11.1.3.5
docker exec spark-master spark-submit --master spark://master:7077 /example/example.py
```
This output proves that PySpark could call into the function `func`, which only works when Pandas and PyArrow are installed:
```
+---+
| id|
+---+
| 0|
| 1|
| 2|
+---+
```
Test that `spark.install_pip_package("pandas", "pyarrow")` is really required by this example by removing this line from `example.py` …
```diff
from pyspark.sql import SparkSession
def main():
spark = SparkSession.builder.appName("spark_app").getOrCreate()
def func(df):
return df
from gresearch.spark import install_pip_package
- spark.install_pip_package("pandas", "pyarrow")
spark.range(0, 3, 1, 5).mapInPandas(func, "id long").show()
if __name__ == "__main__":
main()
```
… and running the `spark-submit` command again. The example does not work anymore,
because the Pandas and PyArrow packages are missing from the driver:
```
Traceback (most recent call last):
File "/opt/spark/python/lib/pyspark.zip/pyspark/sql/pandas/utils.py", line 27, in require_minimum_pandas_version
ModuleNotFoundError: No module named 'pandas'
```
Finally, shutdown the example cluster:
```shell
docker compose -f docker-compose.yml down
```
## Known Issues
Note that this feature is not supported in Python when connected with a [Spark Connect server](README.md#spark-connect-server).
================================================
FILE: README.md
================================================
# Spark Extension
This project provides extensions to the [Apache Spark project](https://spark.apache.org/) in Scala and Python:
**[Diff](DIFF.md):** A `diff` transformation and application for `Dataset`s that computes the differences between
two datasets, i.e. which rows to _add_, _delete_ or _change_ to get from one dataset to the other.
**[SortedGroups](GROUPS.md):** A `groupByKey` transformation that groups rows by a key while providing
a **sorted** iterator for each group. Similar to `Dataset.groupByKey.flatMapGroups`, but with order guarantees
for the iterator.
**[Histogram](HISTOGRAM.md) [[*]](#spark-connect-server):** A `histogram` transformation that computes the histogram DataFrame for a value column.
**[Global Row Number](ROW_NUMBER.md) [[*]](#spark-connect-server):** A `withRowNumbers` transformation that provides the global row number w.r.t.
the current order of the Dataset, or any given order. In contrast to the existing SQL function `row_number`, which
requires a window spec, this transformation provides the row number across the entire Dataset without scaling problems.
**[Partitioned Writing](PARTITIONING.md):** The `writePartitionedBy` action writes your `Dataset` partitioned and
efficiently laid out with a single operation.
**[Inspect Parquet files](PARQUET.md) [[*]](#spark-connect-server):** The structure of Parquet files (the metadata, not the data stored in Parquet) can be inspected similar to [parquet-tools](https://pypi.org/project/parquet-tools/)
or [parquet-cli](https://pypi.org/project/parquet-cli/) by reading from a simple Spark data source.
This simplifies identifying why some Parquet files cannot be split by Spark into scalable partitions.
**[Install Python packages into PySpark job](PYSPARK-DEPS.md) [[*]](#spark-connect-server):** Install Python dependencies via PIP or Poetry programmatically into your running PySpark job (PySpark ≥ 3.1.0):
```python
# noinspection PyUnresolvedReferences
from gresearch.spark import *
# using PIP
spark.install_pip_package("pandas==1.4.3", "pyarrow")
spark.install_pip_package("-r", "requirements.txt")
# using Poetry
spark.install_poetry_project("../my-poetry-project/", poetry_python="../venv-poetry/bin/python")
```
**[Fluent method call](CONDITIONAL.md):** `T.call(transformation: T => R): R`: Turns a transformation `T => R`
that is not part of `T` into a fluent method call on `T`. This allows writing fluent code like:
```scala
import uk.co.gresearch._
i.doThis()
.doThat()
.call(transformation)
.doMore()
```
**[Fluent conditional method call](CONDITIONAL.md):** `T.when(condition: Boolean).call(transformation: T => T): T`:
Perform a transformation fluently only if the given condition is true.
This allows writing fluent code like:
```scala
import uk.co.gresearch._
i.doThis()
.doThat()
.when(condition).call(transformation)
.doMore()
```
**[Shortcut for groupBy.as](https://github.com/G-Research/spark-extension/pull/213#issue-2032837105)**: Calling `Dataset.groupBy(Column*).as[K, T]`
should be preferred over calling `Dataset.groupByKey(V => K)` whenever possible. The former allows Catalyst to exploit
existing partitioning and ordering of the Dataset, while the latter hides from Catalyst which columns are used to create the keys.
This can have a significant performance penalty.
Details:
The new column-expression-based `groupByKey[K](Column*)` method makes it easier to group by a column expression key. Instead of
```scala
ds.groupBy($"id").as[Int, V]
```
use:
```scala
ds.groupByKey[Int]($"id")
```
**Backticks:** `backticks(string: String, strings: String*): String`: Encloses the given column name with backticks (`` ` ``) when needed.
This is a handy way to ensure column names with special characters like dots (`.`) work with `col()` or `select()`.
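A minimal sketch in Scala (the dotted column name is hypothetical):
```scala
// select a column literally named "a.column"; backticks() quotes the name only when needed
import org.apache.spark.sql.functions.col
import uk.co.gresearch.spark.backticks

df.select(col(backticks("a.column")))   // equivalent to col("`a.column`")
```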
**Count null values:** `count_null(e: Column)`: an aggregation function like `count` that counts null values in column `e`.
This is equivalent to calling `count(when(e.isNull, lit(1)))`.
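A minimal sketch, assuming a DataFrame `df` with a nullable `id` column:
```scala
// count non-null and null values of "id" side by side
import org.apache.spark.sql.functions.count
import uk.co.gresearch.spark.count_null

df.select(count($"id").as("ids"), count_null($"id").as("null_ids")).show()
```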
**.Net DateTime.Ticks[[*]](#spark-connect-server):** Convert .Net (C#, F#, Visual Basic) `DateTime.Ticks` into Spark timestamps, seconds and nanoseconds.
Available methods:
```scala
// Scala
dotNetTicksToTimestamp(Column): Column // returns timestamp as TimestampType
dotNetTicksToUnixEpoch(Column): Column // returns Unix epoch seconds as DecimalType
dotNetTicksToUnixEpochNanos(Column): Column // returns Unix epoch nanoseconds as LongType
```
The reverse is provided by (all return `LongType` .Net ticks):
```scala
// Scala
timestampToDotNetTicks(Column): Column
unixEpochToDotNetTicks(Column): Column
unixEpochNanosToDotNetTicks(Column): Column
```
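For example, converting a hypothetical `ticks` column to a timestamp and back (a sketch):
```scala
// Scala: .Net ticks -> Spark timestamp -> .Net ticks ("ticks" is a hypothetical column)
import uk.co.gresearch.spark._

df.select(
    $"ticks",
    dotNetTicksToTimestamp($"ticks").as("timestamp"),
    timestampToDotNetTicks(dotNetTicksToTimestamp($"ticks")).as("ticks_again"))
  .show()
```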
These methods are also available in Python:
```python
# Python
dotnet_ticks_to_timestamp(column_or_name) # returns timestamp as TimestampType
dotnet_ticks_to_unix_epoch(column_or_name) # returns Unix epoch seconds as DecimalType
dotnet_ticks_to_unix_epoch_nanos(column_or_name) # returns Unix epoch nanoseconds as LongType
timestamp_to_dotnet_ticks(column_or_name)
unix_epoch_to_dotnet_ticks(column_or_name)
unix_epoch_nanos_to_dotnet_ticks(column_or_name)
```
**Spark temporary directory[[*]](#spark-connect-server)**: Create a temporary directory that will be removed on Spark application shutdown.
Examples:
Scala:
```scala
import uk.co.gresearch.spark.createTemporaryDir
val dir = createTemporaryDir("prefix")
```
Python:
```python
# noinspection PyUnresolvedReferences
from gresearch.spark import *
dir = spark.create_temporary_dir("prefix")
```
**Spark job description[[*]](#spark-connect-server):** Set Spark job description for all Spark jobs within a context.
Examples:
```scala
import uk.co.gresearch.spark._
implicit val session: SparkSession = spark
withJobDescription("parquet file") {
val df = spark.read.parquet("data.parquet")
val count = appendJobDescription("count") {
df.count
}
appendJobDescription("write") {
df.write.csv("data.csv")
}
}
```
*(Screenshots comparing the Spark UI without job descriptions and with job descriptions.)*
Note that setting a description in one thread while calling the action (e.g. `.count`) in a different thread
does not work, unless the different thread is spawned from the current thread _after_ the description has been set.
Working example with parallel collections:
```scala
import java.util.concurrent.ForkJoinPool
import scala.collection.parallel.CollectionConverters.seqIsParallelizable
import scala.collection.parallel.ForkJoinTaskSupport
val files = Seq("data1.csv", "data2.csv").par
val counts = withJobDescription("Counting rows") {
// new thread pool required to spawn new threads from this thread
// so that the job description is actually used
files.tasksupport = new ForkJoinTaskSupport(new ForkJoinPool())
files.map(filename => spark.read.csv(filename).count).sum
}(spark)
```
## Using Spark Extension
The `spark-extension` package is available for all Spark 3.2, 3.3, 3.4, 3.5 and 4.0 versions.
The package version has the following semantics: `spark-extension_{SCALA_COMPAT_VERSION}-{VERSION}-{SPARK_COMPAT_VERSION}`:
- `SCALA_COMPAT_VERSION`: Scala binary compatibility (minor) version. Available are `2.12` and `2.13`.
- `SPARK_COMPAT_VERSION`: Apache Spark binary compatibility (minor) version. Available are `3.2`, `3.3`, `3.4`, `3.5` and `4.0`.
- `VERSION`: The package version, e.g. `2.14.0`.
### SBT
Add this line to your `build.sbt` file:
```sbt
libraryDependencies += "uk.co.gresearch.spark" %% "spark-extension" % "2.15.0-3.5"
```
### Maven
Add this dependency to your `pom.xml` file:
```xml
<dependency>
  <groupId>uk.co.gresearch.spark</groupId>
  <artifactId>spark-extension_2.12</artifactId>
  <version>2.15.0-3.5</version>
</dependency>
```
### Gradle
Add this dependency to your `build.gradle` file:
```groovy
dependencies {
implementation "uk.co.gresearch.spark:spark-extension_2.12:2.15.0-3.5"
}
```
### Spark Submit
Submit your Spark app with the Spark Extension dependency (version ≥1.1.0) as follows:
```shell script
spark-submit --packages uk.co.gresearch.spark:spark-extension_2.12:2.15.0-3.5 [jar]
```
Note: Pick the right Scala version (here 2.12) and Spark version (here 3.5) depending on your Spark version.
### Spark Shell
Launch a Spark Shell with the Spark Extension dependency (version ≥1.1.0) as follows:
```shell script
spark-shell --packages uk.co.gresearch.spark:spark-extension_2.12:2.15.0-3.5
```
Note: Pick the right Scala version (here 2.12) and Spark version (here 3.5) depending on your Spark Shell version.
### Python
#### PySpark API
Start a PySpark session with the Spark Extension dependency (version ≥1.1.0) as follows:
```python
from pyspark.sql import SparkSession
spark = SparkSession \
.builder \
.config("spark.jars.packages", "uk.co.gresearch.spark:spark-extension_2.12:2.15.0-3.5") \
.getOrCreate()
```
Note: Pick the right Scala version (here 2.12) and Spark version (here 3.5) depending on your PySpark version.
#### PySpark REPL
Launch the Python Spark REPL with the Spark Extension dependency (version ≥1.1.0) as follows:
```shell script
pyspark --packages uk.co.gresearch.spark:spark-extension_2.12:2.15.0-3.5
```
Note: Pick the right Scala version (here 2.12) and Spark version (here 3.5) depending on your PySpark version.
#### PySpark `spark-submit`
Run your Python scripts that use PySpark via `spark-submit`:
```shell script
spark-submit --packages uk.co.gresearch.spark:spark-extension_2.12:2.15.0-3.5 [script.py]
```
Note: Pick the right Scala version (here 2.12) and Spark version (here 3.5) depending on your Spark version.
#### PyPi package (local Spark cluster only)
You may want to install the `pyspark-extension` Python package from PyPi into your development environment.
This provides code completion, typing and test capabilities during your development phase.
Running your Python application on a Spark cluster will still require one of the above ways
to add the Scala package to the Spark environment.
```shell script
pip install pyspark-extension==2.15.0.3.5
```
Note: Pick the right Spark version (here 3.5) depending on your PySpark version.
### Your favorite Data Science notebook
There are plenty of [Data Science notebooks](https://datasciencenotebook.org/) around. To use this library,
add **a jar dependency** to your notebook using these **Maven coordinates**:
```
uk.co.gresearch.spark:spark-extension_2.12:2.15.0-3.5
```
Or [download the jar](https://mvnrepository.com/artifact/uk.co.gresearch.spark/spark-extension) and place it
on a filesystem where it is accessible by the notebook, and reference that jar file directly.
Check the documentation of your favorite notebook to learn how to add jars to your Spark environment.
## Known issues
### Spark Connect Server
Most features are not supported **in Python** in conjunction with a [Spark Connect server](https://spark.apache.org/docs/latest/spark-connect-overview.html).
This also holds for Databricks Runtime environment 13.x and above. Details can be found [in this blog](https://semyonsinchenko.github.io/ssinchenko/post/how-databricks-14x-breaks-3dparty-compatibility/).
Calling any of those features when connected to a Spark Connect server will raise this error:
```
This feature is not supported for Spark Connect.
```
Use a classic connection to a Spark cluster instead.
## Build
You can build this project against different versions of Spark and Scala.
### Switch Spark and Scala version
If you want to build for a Spark or Scala version different to what is defined in the `pom.xml` file, then run
```shell script
sh set-version.sh [SPARK-VERSION] [SCALA-VERSION]
```
For example, switch to Spark 3.5.0 and Scala 2.13.8 by running `sh set-version.sh 3.5.0 2.13.8`.
### Build the Scala project
Then execute `mvn package` to create a jar from the sources. It can be found in `target/`.
## Testing
Run the Scala tests via `mvn test`.
### Setup Python environment
In order to run the Python tests, set up a Python environment as follows:
```shell script
virtualenv -p python3 venv
source venv/bin/activate
pip install python/[test]
```
### Run Python tests
Run the Python tests via `env PYTHONPATH=python/test python -m pytest python/test`.
### Build Python package
Run the following commands in the project root directory to create a whl from the sources:
```shell script
pip install build
python -m build python/
```
It can be found in `python/dist/`.
## Publications
- ***Guaranteeing in-partition order for partitioned-writing in Apache Spark**, Enrico Minack, 20/01/2023*: https://www.gresearch.com/blog/article/guaranteeing-in-partition-order-for-partitioned-writing-in-apache-spark/
- ***Un-pivot, sorted groups and many bug fixes: Celebrating the first Spark 3.4 release**, Enrico Minack, 21/03/2023*: https://www.gresearch.com/blog/article/un-pivot-sorted-groups-and-many-bug-fixes-celebrating-the-first-spark-3-4-release/
- ***A PySpark bug makes co-grouping with window function partition-key-order-sensitive**, Enrico Minack, 29/03/2023*: https://www.gresearch.com/blog/article/a-pyspark-bug-makes-co-grouping-with-window-function-partition-key-order-sensitive/
- ***Spark’s groupByKey should be avoided – and here’s why**, Enrico Minack, 13/06/2023*: https://www.gresearch.com/blog/article/sparks-groupbykey-should-be-avoided-and-heres-why/
- ***Inspecting Parquet files with Spark**, Enrico Minack, 28/07/2023*: https://www.gresearch.com/blog/article/parquet-files-know-your-scaling-limits/
- ***Enhancing Spark’s UI with Job Descriptions**, Enrico Minack, 12/12/2023*: https://www.gresearch.com/blog/article/enhancing-sparks-ui-with-job-descriptions/
- ***PySpark apps with dependencies: Managing Python dependencies in code**, Enrico Minack, 24/01/2024*: https://www.gresearch.com/news/pyspark-apps-with-dependencies-managing-python-dependencies-in-code/
- ***Observing Spark Aggregates: Cheap Metrics from Datasets**, Enrico Minack, 06/02/2024*: https://www.gresearch.com/news/observing-spark-aggregates-cheap-metrics-from-datasets-2/
## Security
Please see our [security policy](https://github.com/G-Research/spark-extension/blob/master/SECURITY.md) for details on reporting security vulnerabilities.
================================================
FILE: RELEASE.md
================================================
# Releasing Spark Extension
This provides instructions on how to release a version of `spark-extension`. We release this library
for a number of Spark and Scala environments, but all from the same git tag. First release for the environment
that is set in the `pom.xml` and create a tag. On success, release from that tag for all other environments
as described below.
Use the `release.sh` script to test and release all versions. Or execute the following steps manually.
## Testing master for all environments
The following steps release a snapshot and test it. Test all versions listed [further down](#releasing-master-for-other-environments).
- Set the version with `./set-version.sh`, e.g. `./set-version.sh 3.4.0 2.12.17`
- Release a snapshot (make sure the version in the `pom.xml` file ends with `SNAPSHOT`): `mvn clean deploy`
- Test the released snapshot: `./test-release.sh`
## Releasing from master
Follow this procedure to release a new version:
- Add a new entry to `CHANGELOG.md` listing all notable changes of this release.
Use the heading `## [VERSION] - YYYY-MM-dd`, e.g. `## [1.1.0] - 2020-03-12`.
- Remove the `-SNAPSHOT` suffix from the version, e.g. `./set-version.sh 1.1.0`.
- Update the versions in the `README.md` and `python/README.md` file to the version of your `pom.xml` to reflect the latest version,
e.g. replace all `1.0.0-3.1` with `1.1.0-3.1` and `1.0.0.3.1` with `1.1.0.3.1`, respectively.
- Commit the change to your local git repository, use a commit message like `Releasing 1.1.0`. Do not push to github yet.
- Tag that commit with a version tag like `v1.1.0` and message like `Release v1.1.0`. Do not push to github yet.
- Release the version with `mvn clean deploy`. This will be put into a staging repository and not automatically released (due to `<autoReleaseAfterClose>false</autoReleaseAfterClose>` in your [`pom.xml`](pom.xml) file).
- Inspect and test the staged version. Use `./test-release.sh` or the `spark-examples` project for that. If you are happy with everything:
- Push the commit and tag to origin.
- Release the package with `mvn nexus-staging:release`.
- Bump the version to the next [minor version](https://semver.org/) and append the `-SNAPSHOT` suffix again: `./set-version.sh 1.2.0-SNAPSHOT`.
- Commit this change to your local git repository, use a commit message like `Post-release version bump to 1.2.0`.
- Push all local commits to origin.
- Otherwise drop it with `mvn nexus-staging:drop`. Remove the last two commits from your local history.
## Releasing master for other environments
Once you have released the new version, release from the same tag for all other Spark and Scala environments as well:
- Release for these environments (one of them has already been released above, from the tagged version):
|Spark|Scala|
|:----|:----|
|3.2 |2.12.15 and 2.13.5|
|3.3 |2.12.15 and 2.13.8|
|3.4 |2.12.17 and 2.13.8|
|3.5 |2.12.17 and 2.13.8|
- Always use the latest Spark version per Spark minor version
- Release process:
- Checkout the release tag, e.g. `git checkout v1.0.0`
- Set the version in the `pom.xml` file via `set-version.sh`, e.g. `./set-version.sh 3.4.0 2.12.17`
- Review the `pom.xml` file changes: `git diff pom.xml`
- Release the version with `mvn clean deploy`
- Inspect and test the staged version. Use `./test-release.sh` or the `spark-examples` project for that.
- If you are happy with everything, release the package with `mvn nexus-staging:release`.
- Otherwise drop it with `mvn nexus-staging:drop`.
- Revert the changes done to the `pom.xml` file: `git checkout pom.xml`
## Releasing a bug-fix version
A bug-fix version needs to be released from a [minor-version branch](https://semver.org/), e.g. `branch-1.1`.
### Create a bug-fix branch
If there is no bug-fix branch yet, create it:
- Create such a branch from the respective [minor-version tag](https://semver.org/), e.g. create minor version branch `branch-1.1` from tag `v1.1.0`.
- Bump the version to the next [patch version](https://semver.org/) in `pom.xml` and append the `-SNAPSHOT` suffix again, e.g. `1.1.0` → `1.1.1-SNAPSHOT`.
- Commit this change to your local git repository, use a commit message like `Post-release version bump to 1.1.1`.
- Push this commit to origin.
Merge your bug fixes into this branch as you would normally do for master, use PRs for that.
### Release from a bug-fix branch
This is very similar to [releasing from master](#releasing-from-master),
but the version increment occurs on [patch level](https://semver.org/):
- Add a new entry to `CHANGELOG.md` listing all notable changes of this release.
Use the heading `## [VERSION] - YYYY-MM-dd`, e.g. `## [1.1.1] - 2020-03-12`.
- Remove the `-SNAPSHOT` suffix from the version, e.g. `./set-version.sh 1.1.1`.
- Update the versions in the `README.md` and `python/README.md` file to the version of your `pom.xml` to reflect the latest version,
e.g. replace all `1.1.0-3.1` with `1.1.1-3.1` and `1.1.0.3.1` with `1.1.1.3.1`, respectively.
- Commit the change to your local git repository, use a commit message like `Releasing 1.1.1`. Do not push to github yet.
- Tag that commit with a version tag like `v1.1.1` and message like `Release v1.1.1`. Do not push to github yet.
- Release the version with `mvn clean deploy`. This will be put into a staging repository and not automatically released (due to `<autoReleaseAfterClose>false</autoReleaseAfterClose>` in your [`pom.xml`](pom.xml) file).
- Inspect and test the staged version. Use `./test-release.sh` or the `spark-examples` project for that. If you are happy with everything:
- Push the commit and tag to origin.
- Release the package with `mvn nexus-staging:release`.
- Bump the version to the next [patch version](https://semver.org/) and append the `-SNAPSHOT` suffix again: `./set-version.sh 1.1.2-SNAPSHOT`.
- Commit this change to your local git repository, use a commit message like `Post-release version bump to 1.1.2`.
- Push all local commits to origin.
- Otherwise drop it with `mvn nexus-staging:drop`. Remove the last two commits from your local history.
Consider releasing the bug-fix version for other environments as well. See [above](#releasing-master-for-other-environments) section for details.
================================================
FILE: ROW_NUMBER.md
================================================
# Global Row Number
Spark provides the [SQL function `row_number`](https://spark.apache.org/docs/latest/api/sql/index.html#row_number),
which assigns each row a consecutive number, starting from 1. This function works on a [Window](https://spark.apache.org/docs/latest/api/scala/org/apache/spark/sql/expressions/Window.html).
Assigning a row number over the entire Dataset this way loads the entire Dataset into a single partition / executor.
This does not scale.
Spark Extension provides the `Dataset` transformation `withRowNumbers`, which assigns a global row number while scaling:
```scala
val df = Seq((1, "one"), (2, "TWO"), (2, "two"), (3, "three")).toDF("id", "value")
df.show()
// +---+-----+
// | id|value|
// +---+-----+
// | 1| one|
// | 2| TWO|
// | 2| two|
// | 3|three|
// +---+-----+
import uk.co.gresearch.spark._
df.withRowNumbers().show()
// +---+-----+----------+
// | id|value|row_number|
// +---+-----+----------+
// | 1| one| 1|
// | 2| two| 2|
// | 2| TWO| 3|
// | 3|three| 4|
// +---+-----+----------+
```
In Java:
```java
import uk.co.gresearch.spark.RowNumbers;
RowNumbers.of(df).show();
// +---+-----+----------+
// | id|value|row_number|
// +---+-----+----------+
// | 1| one| 1|
// | 2| two| 2|
// | 2| TWO| 3|
// | 3|three| 4|
// +---+-----+----------+
```
In Python:
```python
import gresearch.spark
df.with_row_numbers().show()
# +---+-----+----------+
# | id|value|row_number|
# +---+-----+----------+
# | 1| one| 1|
# | 2| two| 2|
# | 2| TWO| 3|
# | 3|three| 4|
# +---+-----+----------+
```
## Row number order
Row numbers are assigned in the current order of the Dataset. If you want a specific order, provide columns as follows:
```scala
df.withRowNumbers($"id".desc, $"value").show()
// +---+-----+----------+
// | id|value|row_number|
// +---+-----+----------+
// | 3|three| 1|
// | 2| TWO| 2|
// | 2| two| 3|
// | 1| one| 4|
// +---+-----+----------+
```
In Java:
```java
RowNumbers.withOrderColumns(df.col("id").desc(), df.col("value")).of(df).show();
// +---+-----+----------+
// | id|value|row_number|
// +---+-----+----------+
// | 3|three| 1|
// | 2| TWO| 2|
// | 2| two| 3|
// | 1| one| 4|
// +---+-----+----------+
```
In Python:
```python
df.with_row_numbers(order=[df.id.desc(), df.value]).show()
# +---+-----+----------+
# | id|value|row_number|
# +---+-----+----------+
# | 3|three| 1|
# | 2| TWO| 2|
# | 2| two| 3|
# | 1| one| 4|
# +---+-----+----------+
```
## Row number column name
The column name that contains the row number can be changed by providing the `rowNumberColumnName` argument:
```scala
df.withRowNumbers(rowNumberColumnName="row").show()
// +---+-----+---+
// | id|value|row|
// +---+-----+---+
// | 1| one| 1|
// | 2| TWO| 2|
// | 2| two| 3|
// | 3|three| 4|
// +---+-----+---+
```
In Java:
```java
RowNumbers.withRowNumberColumnName("row").of(df).show();
// +---+-----+---+
// | id|value|row|
// +---+-----+---+
// | 1| one| 1|
// | 2| TWO| 2|
// | 2| two| 3|
// | 3|three| 4|
// +---+-----+---+
```
In Python:
```python
df.with_row_numbers(row_number_column_name='row').show()
# +---+-----+---+
# | id|value|row|
# +---+-----+---+
# | 1| one| 1|
# | 2| TWO| 2|
# | 2| two| 3|
# | 3|three| 4|
# +---+-----+---+
```
## Cached / persisted intermediate Dataset
The `withRowNumbers` transformation requires the input Dataset to be
[cached](https://spark.apache.org/docs/latest/api/scala/org/apache/spark/sql/Dataset.html#cache():Dataset.this.type) /
[persisted](https://spark.apache.org/docs/latest/api/scala/org/apache/spark/sql/Dataset.html#persist(newLevel:org.apache.spark.storage.StorageLevel):Dataset.this.type)
after adding an intermediate column. The level of persistence can be set through the `storageLevel` parameter, which defaults to `MEMORY_AND_DISK`.
```scala
import org.apache.spark.storage.StorageLevel
val dfWithRowNumbers = df.withRowNumbers(storageLevel=StorageLevel.DISK_ONLY)
```
In Java:
```java
import org.apache.spark.storage.StorageLevel;
Dataset dfWithRowNumbers = RowNumbers.withStorageLevel(StorageLevel.DISK_ONLY()).of(df);
```
In Python:
```python
from pyspark.storagelevel import StorageLevel
df_with_row_numbers = df.with_row_numbers(storage_level=StorageLevel.DISK_ONLY)
```
## Un-persist intermediate Dataset
If you want control over when to un-persist this intermediate Dataset, you can provide an `UnpersistHandle` and call it
when you are done with the result Dataset:
```scala
import uk.co.gresearch.spark.UnpersistHandle
val unpersist = UnpersistHandle()
val dfWithRowNumbers = df.withRowNumbers(unpersistHandle=unpersist)
// after you are done with dfWithRowNumbers you may want to call unpersist()
unpersist(blocking=false)
```
In Java:
```java
import uk.co.gresearch.spark.UnpersistHandle;
UnpersistHandle unpersist = new UnpersistHandle();
Dataset dfWithRowNumbers = RowNumbers.withUnpersistHandle(unpersist).of(df);
// after you are done with dfWithRowNumbers you may want to call unpersist()
unpersist.apply(true);
```
In Python:
```python
unpersist = spark.unpersist_handle()
df_with_row_numbers = df.with_row_numbers(unpersist_handle=unpersist)
# after you are done with df_with_row_numbers you may want to call unpersist()
unpersist(blocking=True)
```
## Spark warning
You may notice that Spark logs the following warning:
```
WindowExec: No Partition Defined for Window operation! Moving all data to a single partition, this can cause serious performance degradation.
```
This warning is unavoidable, because `withRowNumbers` has to pull information about the initial partitions into a single partition.
Fortunately, only 12 bytes per input partition are required, so even 100,000 input partitions amount to roughly 1.2 MB. This easily fits into a single partition, and the warning can safely be ignored.
## Known issues
Note that this feature is not supported in Python when connected with a [Spark Connect server](README.md#spark-connect-server).
================================================
FILE: SECURITY.md
================================================
# Security and Coordinated Vulnerability Disclosure Policy
This project appreciates and encourages coordinated disclosure of security vulnerabilities. We prefer that you use the GitHub reporting mechanism to privately report vulnerabilities. Under the main repository's security tab, click "Report a vulnerability" to open the advisory form.
If you are unable to report it via GitHub, have received no response after repeated attempts, or have other security related questions, please contact security@gr-oss.io and mention this project in the subject line.
================================================
FILE: build-whl.sh
================================================
#!/bin/bash
set -eo pipefail
base=$(cd "$(dirname "$0")"; pwd)
version=$(grep --max-count=1 "<version>.*</version>" "$base/pom.xml" | sed -E -e "s/\s*<[^>]+>//g")
artifact_id=$(grep --max-count=1 "<artifactId>.*</artifactId>" "$base/pom.xml" | sed -E -e "s/\s*<[^>]+>//g")
rm -rf "$base/python/pyspark/jars/$artifact_id-*.jar"
pip install build
python -m build "$base/python/"
# check for missing modules in whl file
pyversion=${version/SNAPSHOT/dev0}
pyversion=${pyversion//-/.}
missing="$(diff <(cd "$base/python"; find gresearch -type f | grep -v ".pyc$" | sort) <(unzip -l "$base/python/dist/pyspark_extension-${pyversion}-*.whl" | tail -n +4 | head -n -2 | sed -E -e "s/^ +//" -e "s/ +/ /g" | cut -d " " -f 4- | sort) | grep "^<" || true)"
if [ -n "$missing" ]
then
echo "These files are missing from the whl file:"
echo "$missing"
exit 1
fi
jars=$(unzip -l "$base/python/dist/pyspark_extension-${pyversion}-*.whl" | grep ".jar" | wc -l)
if [ $jars -ne 1 ]
then
echo "Expected exactly one jar in whl file, but $jars found!"
exit 1
fi
================================================
FILE: bump-version.sh
================================================
#!/bin/bash
#
# Copyright 2020 G-Research
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# Script to prepare release, see RELEASE.md for details
set -e -o pipefail
# check for clean git status
readarray -t git_status < <(git status -s --untracked-files=no 2>/dev/null)
if [ ${#git_status[@]} -gt 0 ]
then
echo "There are pending git changes:"
for (( i=0; i<${#git_status[@]}; i++ )); do echo "${git_status[$i]}" ; done
exit 1
fi
function next_version {
local version=$1
local branch=$2
patch=${version/*./}
majmin=${version%.${patch}}
if [[ $branch == "master" ]]
then
# minor version bump
if [[ $version != *".0" ]]
then
echo "version is patch version, should be M.m.0: $version" >&2
exit 1
fi
maj=${version/.*/}
min=${majmin#${maj}.}
next=${maj}.$((min+1)).0
echo "$next"
else
# patch version bump
next=${majmin}.$((patch+1))
echo "$next"
fi
}
# get release and next version
version=$(grep --max-count=1 "<version>.*</version>" pom.xml | sed -E -e "s/\s*<[^>]+>//g")
pkg_version="${version/-*/}"
branch=$(git rev-parse --abbrev-ref HEAD)
next_pkg_version="$(next_version "$pkg_version" "$branch")"
# bump the version
echo "Bump version to $next_pkg_version"
./set-version.sh $next_pkg_version-SNAPSHOT
# commit changes to local repo
echo
echo "Committing release to local git"
git commit -a -m "Post-release version bump to $next_pkg_version"
git show HEAD
echo
# push version bump to origin
echo "Press to push commit to origin"
read
echo "Pushing release commit to origin"
git push origin "master"
echo
================================================
FILE: examples/python-deps/Dockerfile
================================================
FROM apache/spark:3.5.0
ENV PATH="${PATH}:/opt/spark/bin"
USER root
RUN mkdir -p /home/spark; chown spark:spark /home/spark
USER spark
================================================
FILE: examples/python-deps/docker-compose.yml
================================================
version: "3"
services:
master:
container_name: spark-master
image: spark-extension-example-docker
command: /opt/spark/bin/spark-class org.apache.spark.deploy.master.Master -h master
environment:
MASTER: spark://master:7077
SPARK_PUBLIC_DNS: localhost
SPARK_MASTER_WEBUI_PORT: 8080
PYSPARK_PYTHON: python${PYTHON_VERSION:-3.8}
PYSPARK_DRIVER_PYTHON: python${PYTHON_VERSION:-3.8}
expose:
- 7077
ports:
- 4040:4040
- 8080:8080
volumes:
- ./:/example
worker:
container_name: spark-worker
image: spark-extension-example-docker
command: /opt/spark/bin/spark-class org.apache.spark.deploy.worker.Worker spark://master:7077
environment:
SPARK_WORKER_CORES: 1
SPARK_WORKER_MEMORY: 1g
SPARK_WORKER_PORT: 8881
SPARK_WORKER_WEBUI_PORT: 8081
SPARK_PUBLIC_DNS: localhost
links:
- master
ports:
- 8081:8081
================================================
FILE: examples/python-deps/example.py
================================================
from pyspark.sql import SparkSession
def main():
spark = SparkSession.builder.appName("spark_app").getOrCreate()
def func(df):
return df
from gresearch.spark import install_pip_package
spark.install_pip_package("pandas", "pyarrow")
spark.range(0, 3, 1, 5).mapInPandas(func, "id long").show()
if __name__ == "__main__":
main()
================================================
FILE: pom.xml
================================================
<?xml version="1.0" encoding="UTF-8"?>
<!-- outline reconstructed from flattened content; dependency and plugin details summarized in comments below -->
<project xmlns="http://maven.apache.org/POM/4.0.0">
    <modelVersion>4.0.0</modelVersion>
    <groupId>uk.co.gresearch.spark</groupId>
    <artifactId>spark-extension_2.13</artifactId>
    <version>2.16.0-3.5-SNAPSHOT</version>
    <name>Spark Extension</name>
    <description>A library that provides useful extensions to Apache Spark.</description>
    <inceptionYear>2020</inceptionYear>
    <url>https://github.com/G-Research</url>
    <licenses>
        <license>
            <name>Apache 2.0 License</name>
            <url>http://www.apache.org/licenses/LICENSE-2.0.html</url>
            <distribution>repo</distribution>
        </license>
    </licenses>
    <scm>
        <connection>scm:git:git://github.com/g-research/spark-extension.git</connection>
        <developerConnection>scm:git:ssh://github.com:g-research/spark-extension.git</developerConnection>
        <url>https://github.com/g-research/spark-extension/tree/${project.scm.tag}</url>
        <tag>master</tag>
    </scm>
    <developers>
        <developer>
            <id>EnricoMi</id>
            <name>Enrico Minack</name>
            <email>github@enrico.minack.dev</email>
        </developer>
    </developers>
    <issueManagement>
        <system>GitHub Issues</system>
        <url>https://github.com/G-Research/spark-extension/issues</url>
    </issueManagement>
    <properties>
        <java.version>1.8</java.version>
        <maven.compiler.source>${java.version}</maven.compiler.source>
        <maven.compiler.target>${java.version}</maven.compiler.target>
        <project.build.sourceEncoding>UTF-8</project.build.sourceEncoding>
        <scala.major.version>2</scala.major.version>
        <scala.minor.version>13</scala.minor.version>
        <scala.patch.version>8</scala.patch.version>
        <scala.compat.version>${scala.major.version}.${scala.minor.version}</scala.compat.version>
        <scala.version>${scala.compat.version}.${scala.patch.version}</scala.version>
        <spark.major.version>3</spark.major.version>
        <spark.minor.version>5</spark.minor.version>
        <spark.patch.version>1</spark.patch.version>
        <spark.compat.version>${spark.major.version}.${spark.minor.version}</spark.compat.version>
        <spark.version>${spark.compat.version}.${spark.patch.version}</spark.version>
    </properties>
    <!-- dependencies (summarized):
         org.scala-lang:scala-library:${scala.version};
         org.apache.spark:spark-sql_${scala.compat.version}:${spark.version} (provided, excluding
             org.apache.parquet:*, io.airlift:aircompressor, org.xerial.snappy:snappy-java, org.slf4j:slf4j-api);
         org.apache.spark:spark-catalyst_${scala.compat.version}:${spark.version} (provided);
         org.apache.spark:spark-hive_${scala.compat.version}:${spark.version} (provided);
         org.apache.parquet:parquet-hadoop:1.16.0 (provided, excluding commons-pool:commons-pool,
             javax.annotation:javax.annotation-api, com.github.luben:zstd-jni);
         com.github.scopt:scopt_${scala.compat.version}:4.1.0;
         test scope: org.apache.spark:spark-catalyst (tests classifier), junit:junit:4.13.2,
             org.scalatest:scalatest_${scala.compat.version}:3.3.0-SNAP4,
             org.scalatestplus:scalatestplus-junit_${scala.compat.version}:1.0.0-M2,
             org.apache.parquet:parquet-hadoop:1.16.0 (tests classifier) -->
    <!-- repositories: Maven Central, Apache Releases, Apache Snapshots,
         Apache staging (https://repository.apache.org/content/repositories/orgapachespark-1478/) -->
    <!-- build: sources in src/main/scala plus src/main/scala-spark-${spark.compat.version} and
         src/test/scala-spark-${spark.major.version} (via build-helper-maven-plugin 3.5.0),
         python resources gresearch/**/*.py, test resources src/test/resources;
         plugins: properties-maven-plugin 1.2.1 (writes spark-extension-build.properties),
         maven-scala-plugin 2.15.2, maven-jar-plugin 3.3.0 (main class uk.co.gresearch.spark.diff.App),
         spotless-maven-plugin 2.30.0 (scalafmt 3.7.17, ${project.basedir}/.scalafmt.conf),
         maven-surefire-plugin 3.1.2, maven-source-plugin 3.3.0, scala-maven-plugin 4.8.1 (javadoc jar),
         maven-failsafe-plugin 3.3.0, central-publishing-maven-plugin 0.8.0 (server id: central),
         maven-gpg-plugin 3.1.0, maven-surefire-report-plugin 3.1.2 -->
</project>
================================================
FILE: python/README.md
================================================
# Spark Extension
This project provides extensions to the [Apache Spark project](https://spark.apache.org/) in Scala and Python:
**[Diff](https://github.com/G-Research/spark-extension/blob/v2.13.0/DIFF.md):** A `diff` transformation and application for `Dataset`s that computes the differences between
two datasets, i.e. which rows to _add_, _delete_ or _change_ to get from one dataset to the other.
**[Histogram](https://github.com/G-Research/spark-extension/blob/v2.13.0/HISTOGRAM.md):** A `histogram` transformation that computes the histogram DataFrame for a value column.
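A minimal sketch of the histogram API, assuming a made-up DataFrame with `score` and `user` columns (see the linked HISTOGRAM.md for the full example):
```python
# noinspection PyUnresolvedReferences
from gresearch.spark import *

df = spark.createDataFrame([(1, 3), (1, 7), (2, 11)], ['user', 'score'])
# bucket the 'score' values by the thresholds 0, 5 and 10, one row per 'user'
df.histogram([0, 5, 10], 'score', 'user').orderBy('user').show()
```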
**[Global Row Number](https://github.com/G-Research/spark-extension/blob/v2.13.0/ROW_NUMBER.md):** A `withRowNumbers` transformation that provides the global row number w.r.t.
the current order of the Dataset, or any given order. In contrast to the existing SQL function `row_number`, which
requires a window spec, this transformation provides the row number across the entire Dataset without scaling problems.
**[Inspect Parquet files](https://github.com/G-Research/spark-extension/blob/v2.13.0/PARQUET.md):** The structure of Parquet files (the metadata, not the data stored in Parquet) can be inspected similar to [parquet-tools](https://pypi.org/project/parquet-tools/)
or [parquet-cli](https://pypi.org/project/parquet-cli/) by reading from a simple Spark data source.
This simplifies identifying why some Parquet files cannot be split by Spark into scalable partitions.
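As a sketch, assuming the Python API documented in the linked PARQUET.md:
```python
# noinspection PyUnresolvedReferences
import gresearch.spark.parquet

# one row of metadata per Parquet file
spark.read.parquet_metadata('/path/to/parquet').show()
```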
**[Install Python packages into PySpark job](https://github.com/G-Research/spark-extension/blob/v2.13.0/PYSPARK-DEPS.md):** Install Python dependencies via PIP or Poetry programmatically into your running PySpark job (PySpark ≥ 3.1.0):
```python
# noinspection PyUnresolvedReferences
from gresearch.spark import *
# using PIP
spark.install_pip_package("pandas==1.4.3", "pyarrow")
spark.install_pip_package("-r", "requirements.txt")
# using Poetry
spark.install_poetry_project("../my-poetry-project/", poetry_python="../venv-poetry/bin/python")
```
**Count null values:** `count_null(e: Column)`: an aggregation function like `count` that counts null values in column `e`.
This is equivalent to calling `count(when(e.isNull, lit(1)))`.
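For example, a small sketch with a made-up DataFrame:
```python
from pyspark.sql.functions import count
from gresearch.spark import count_null

df = spark.createDataFrame([(1, 'one'), (2, None)], ['id', 'value'])
# count counts the non-null values, count_null the null values
df.select(count('value'), count_null('value')).show()
```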
**.Net DateTime.Ticks:** Convert .Net (C#, F#, Visual Basic) `DateTime.Ticks` into Spark timestamps, seconds and nanoseconds.
Available methods:
```python
dotnet_ticks_to_timestamp(column_or_name) # returns timestamp as TimestampType
dotnet_ticks_to_unix_epoch(column_or_name) # returns Unix epoch seconds as DecimalType
dotnet_ticks_to_unix_epoch_nanos(column_or_name) # returns Unix epoch nanoseconds as LongType
```
The reverse is provided by (all return `LongType` .Net ticks):
```python
timestamp_to_dotnet_ticks(column_or_name)
unix_epoch_to_dotnet_ticks(column_or_name)
unix_epoch_nanos_to_dotnet_ticks(column_or_name)
```
**Spark temporary directory**: Create a temporary directory that will be removed on Spark application shutdown.
Example:
```python
# noinspection PyUnresolvedReferences
from gresearch.spark import *
dir = spark.create_temporary_dir("prefix")
```
**Spark job description:** Set Spark job description for all Spark jobs within a context.
Example:
```python
from gresearch.spark import job_description, append_job_description
with job_description("parquet file"):
df = spark.read.parquet("data.parquet")
with append_job_description("count"):
count = df.count()
with append_job_description("write"):
df.write.csv("data.csv")
```
For details, see the [README.md](https://github.com/G-Research/spark-extension#spark-extension) at the project homepage.
## Using Spark Extension
#### PyPi package (local Spark cluster only)
You may want to install the `pyspark-extension` python package from PyPi into your development environment.
This provides code completion, typing and test capabilities during your development phase.
Running your Python application on a Spark cluster will still require one of the ways below
to add the Scala package to the Spark environment.
```shell script
pip install pyspark-extension==2.15.0.3.4
```
Note: Pick the right Spark version (here 3.4) depending on your PySpark version.
#### PySpark API
Start a PySpark session with the Spark Extension dependency (version ≥1.1.0) as follows:
```python
from pyspark.sql import SparkSession
spark = SparkSession \
.builder \
.config("spark.jars.packages", "uk.co.gresearch.spark:spark-extension_2.12:2.15.0-3.4") \
.getOrCreate()
```
Note: Pick the right Scala version (here 2.12) and Spark version (here 3.4) depending on your PySpark version.
#### PySpark REPL
Launch the Python Spark REPL with the Spark Extension dependency (version ≥1.1.0) as follows:
```shell script
pyspark --packages uk.co.gresearch.spark:spark-extension_2.12:2.15.0-3.4
```
Note: Pick the right Scala version (here 2.12) and Spark version (here 3.4) depending on your PySpark version.
#### PySpark `spark-submit`
Run your Python scripts that use PySpark via `spark-submit`:
```shell script
spark-submit --packages uk.co.gresearch.spark:spark-extension_2.12:2.15.0-3.4 [script.py]
```
Note: Pick the right Scala version (here 2.12) and Spark version (here 3.4) depending on your Spark version.
### Your favorite Data Science notebook
There are plenty of [Data Science notebooks](https://datasciencenotebook.org/) around. To use this library,
add **a jar dependency** to your notebook using these **Maven coordinates**:
```
uk.co.gresearch.spark:spark-extension_2.12:2.15.0-3.4
```
Or [download the jar](https://mvnrepository.com/artifact/uk.co.gresearch.spark/spark-extension) and place it
on a filesystem where it is accessible by the notebook, and reference that jar file directly.
Check the documentation of your favorite notebook to learn how to add jars to your Spark environment.
================================================
FILE: python/gresearch/__init__.py
================================================
# Copyright 2020 G-Research
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
================================================
FILE: python/gresearch/spark/__init__.py
================================================
# Copyright 2020 G-Research
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import re
import shutil
import subprocess
import sys
import tempfile
import time
from contextlib import contextmanager
from pathlib import Path
from typing import Any, Union, List, Optional, Mapping, Iterable, TYPE_CHECKING
from py4j.java_gateway import JVMView, JavaObject
from pyspark import __version__
from pyspark.context import SparkContext
from pyspark.files import SparkFiles
from pyspark.sql import DataFrame, DataFrameReader, SQLContext
from pyspark.sql.column import Column
from pyspark.sql.context import SQLContext
from pyspark import SparkConf
from pyspark.sql.functions import col, count, lit, when
from pyspark.sql.session import SparkSession
from pyspark.storagelevel import StorageLevel
if __version__.startswith('4.'):
from pyspark.sql.classic.column import _to_java_column
else:
from pyspark.sql.column import _to_java_column
try:
from pyspark.sql.connect.column import Column as ConnectColumn
from pyspark.sql.connect.dataframe import DataFrame as ConnectDataFrame
from pyspark.sql.connect.readwriter import DataFrameReader as ConnectDataFrameReader
from pyspark.sql.connect.session import SparkSession as ConnectSparkSession
has_connect = True
except ImportError:
has_connect = False
if TYPE_CHECKING:
from pyspark.sql._typing import ColumnOrName
_java_pkg_is_installed: Optional[bool] = None
_column_types = (Column,)
_dataframe_types = (DataFrame,)
if has_connect:
_column_types += (ConnectColumn, )
_dataframe_types += (ConnectDataFrame, )
_column_types_and_str = (str,) + _column_types
def _is_column(obj: Any) -> bool:
return isinstance(obj, _column_types)
def _is_column_or_str(obj: Any) -> bool:
return isinstance(obj, _column_types_and_str)
def _is_dataframe(obj: Any) -> bool:
return isinstance(obj, _dataframe_types)
def _check_java_pkg_is_installed(jvm: JVMView) -> bool:
"""Check that the Java / Scala package is installed."""
try:
jvm.uk.co.gresearch.spark.__getattr__("package$").__getattr__("MODULE$").VersionString()
return True
except TypeError as e:
print(e.args)
return False
except:
# any other exception indicates some problem; be safe and do not fail fast here
return True
def _get_jvm(obj: Any) -> JVMView:
"""
Provides easy access to the JVMView provided by Spark, and raises a meaningful error if it is not available.
Also checks that the Java / Scala package is accessible via this JVMView.
"""
if obj is None:
if SparkContext._active_spark_context is None:
raise RuntimeError("This method must be called inside an active Spark session")
else:
raise ValueError("Cannot provide access to JVM from None")
if has_connect and isinstance(obj, (ConnectDataFrame, ConnectDataFrameReader, ConnectSparkSession)):
raise RuntimeError('This feature is not supported for Spark Connect. Please use a classic Spark client. '
'https://github.com/G-Research/spark-extension#spark-connect-server')
if isinstance(obj, DataFrame):
jvm = _get_jvm(obj._sc)
elif isinstance(obj, DataFrameReader):
jvm = _get_jvm(obj._spark)
elif isinstance(obj, SparkSession):
jvm = _get_jvm(obj.sparkContext)
elif isinstance(obj, (SparkContext, SQLContext)):
jvm = obj._jvm
else:
raise RuntimeError(f'Unsupported class: {type(obj)}')
global _java_pkg_is_installed
if _java_pkg_is_installed is None:
_java_pkg_is_installed = _check_java_pkg_is_installed(jvm)
if not _java_pkg_is_installed:
raise RuntimeError("Java / Scala package not found! You need to add the Maven spark-extension package "
"to your PySpark environment: https://github.com/G-Research/spark-extension#python")
return jvm
def _to_seq(jvm: JVMView, list: List[Any]) -> JavaObject:
array = jvm.java.util.ArrayList(list)
return jvm.scala.collection.JavaConverters.asScalaIteratorConverter(array.iterator()).asScala().toSeq()
def _to_map(jvm: JVMView, map: Mapping[Any, Any]) -> JavaObject:
return jvm.scala.collection.JavaConverters.mapAsScalaMap(map)
def backticks(*name_parts: str) -> str:
for np in name_parts:
assert isinstance(np, str), np
return '.'.join([f'`{part}`'
if '.' in part and not part.startswith('`') and not part.endswith('`')
else part
for part in name_parts])
def distinct_prefix_for(existing: List[str]) -> str:
assert isinstance(existing, Iterable)
for e in existing:
assert isinstance(e, str), e
# count number of suffix _ for each existing column name
length = 1
if existing:
length = max([len(name) - len(name.lstrip('_')) for name in existing]) + 1
# return string with one more _ than that
return '_' * length
def handle_configured_case_sensitivity(column_name: str, case_sensitive: bool) -> str:
"""
Produces a column name that considers configured case-sensitivity of column names. When case sensitivity is
deactivated, it lower-cases the given column name and no-ops otherwise.
"""
assert isinstance(column_name, str), column_name
assert isinstance(case_sensitive, bool), case_sensitive
if case_sensitive:
return column_name
return column_name.lower()
def list_contains_case_sensitivity(column_names: Iterable[str], columnName: str, case_sensitive: bool) -> bool:
assert isinstance(column_names, Iterable), column_names
for cn in column_names:
assert isinstance(cn, str), cn
assert isinstance(columnName, str), columnName
assert isinstance(case_sensitive, bool), case_sensitive
return handle_configured_case_sensitivity(columnName, case_sensitive) in [handle_configured_case_sensitivity(c, case_sensitive) for c in column_names]
def list_filter_case_sensitivity(column_names: Iterable[str], filter: Iterable[str], case_sensitive: bool) -> List[str]:
assert isinstance(column_names, Iterable), column_names
for cn in column_names:
assert isinstance(cn, str), cn
assert isinstance(filter, Iterable), filter
for f in filter:
assert isinstance(f, str), f
assert isinstance(case_sensitive, bool), case_sensitive
filter_set = {handle_configured_case_sensitivity(f, case_sensitive) for f in filter}
return [c for c in column_names if handle_configured_case_sensitivity(c, case_sensitive) in filter_set]
def list_diff_case_sensitivity(column_names: Iterable[str], other: Iterable[str], case_sensitive: bool) -> List[str]:
assert isinstance(column_names, Iterable), column_names
for cn in column_names:
assert isinstance(cn, str), cn
assert isinstance(other, Iterable), other
for o in other:
assert isinstance(o, str), o
assert isinstance(case_sensitive, bool), case_sensitive
other_set = {handle_configured_case_sensitivity(f, case_sensitive) for f in other}
return [c for c in column_names if handle_configured_case_sensitivity(c, case_sensitive) not in other_set]
def dotnet_ticks_to_timestamp(tick_column: Union[str, Column]) -> Column:
"""
Convert a .Net `DateTime.Ticks` timestamp to a Spark timestamp. The input column must be
convertible to a number (e.g. string, int, long). The Spark timestamp type does not support
nanoseconds, so the last digit of the timestamp (1/10 of a microsecond) is lost.
{{{
df.select(col("ticks"), dotNetTicksToTimestamp("ticks").alias("timestamp")).show(false)
}}}
+------------------+--------------------------+
|ticks |timestamp |
+------------------+--------------------------+
|638155413748959318|2023-03-27 21:16:14.895931|
+------------------+--------------------------+
Note: the example timestamp lacks the 8/10 of a microsecond. Use `dotNetTicksToUnixEpoch` to
preserve the full precision of the tick timestamp.
https://learn.microsoft.com/de-de/dotnet/api/system.datetime.ticks
:param tick_column: column with a tick value (str or Column)
:return: timestamp column
"""
if not _is_column_or_str(tick_column):
raise ValueError(f"Given column must be a column name (str) or column instance (Column): {type(tick_column)}")
jvm = _get_jvm(SparkContext._active_spark_context)
func = jvm.uk.co.gresearch.spark.__getattr__("package$").__getattr__("MODULE$").dotNetTicksToTimestamp
return Column(func(_to_java_column(tick_column)))
def dotnet_ticks_to_unix_epoch(tick_column: Union[str, Column]) -> Column:
"""
Convert a .Net `DateTime.Ticks` timestamp to a Unix epoch decimal. The input column must be
convertible to a number (e.g. string, int, long). The full precision of the tick timestamp
is preserved (1/10 of a microsecond).
Example:
{{{
df.select(col("ticks"), dotNetTicksToUnixEpoch("ticks").alias("timestamp")).show(false)
}}}
+------------------+--------------------+
|ticks |timestamp |
+------------------+--------------------+
|638155413748959318|1679944574.895931800|
+------------------+--------------------+
https://learn.microsoft.com/de-de/dotnet/api/system.datetime.ticks
:param tick_column: column with a tick value (str or Column)
:return: Unix epoch column
"""
if not _is_column_or_str(tick_column):
raise ValueError(f"Given column must be a column name (str) or column instance (Column): {type(tick_column)}")
jvm = _get_jvm(SparkContext._active_spark_context)
func = jvm.uk.co.gresearch.spark.__getattr__("package$").__getattr__("MODULE$").dotNetTicksToUnixEpoch
return Column(func(_to_java_column(tick_column)))
def dotnet_ticks_to_unix_epoch_nanos(tick_column: Union[str, Column]) -> Column:
"""
Convert a .Net `DateTime.Ticks` timestamp to a Unix epoch nanoseconds. The input column must be
convertible to a number (e.g. string, int, long). The full precision of the tick timestamp
is preserved (1/10 of a microsecond).
Example:
{{{
df.select(col("ticks"), dotNetTicksToUnixEpoch("ticks").alias("timestamp")).show(false)
}}}
+------------------+-------------------+
|ticks |timestamp |
+------------------+-------------------+
|638155413748959318|1679944574895931800|
+------------------+-------------------+
https://learn.microsoft.com/de-de/dotnet/api/system.datetime.ticks
:param tick_column: column with a tick value (str or Column)
:return: Unix epoch column
"""
if not _is_column_or_str(tick_column):
raise ValueError(f"Given column must be a column name (str) or column instance (Column): {type(tick_column)}")
jvm = _get_jvm(SparkContext._active_spark_context)
func = jvm.uk.co.gresearch.spark.__getattr__("package$").__getattr__("MODULE$").dotNetTicksToUnixEpochNanos
return Column(func(_to_java_column(tick_column)))
def timestamp_to_dotnet_ticks(timestamp_column: Union[str, Column]) -> Column:
"""
Convert a Spark timestamp to a .Net `DateTime.Ticks` timestamp.
The input column must be of TimestampType.
Example:
{{{
df.select(col("timestamp"), timestampToDotNetTicks("timestamp").alias("ticks")).show(false)
}}}
+--------------------------+------------------+
|timestamp |ticks |
+--------------------------+------------------+
|2023-03-27 21:16:14.895931|638155413748959310|
+--------------------------+------------------+
https://learn.microsoft.com/de-de/dotnet/api/system.datetime.ticks
:param timestamp_column: column with a timestamp value
:return: tick value column
"""
if not _is_column_or_str(timestamp_column):
raise ValueError(f"Given column must be a column name (str) or column instance (Column): {type(timestamp_column)}")
jvm = _get_jvm(SparkContext._active_spark_context)
func = jvm.uk.co.gresearch.spark.__getattr__("package$").__getattr__("MODULE$").timestampToDotNetTicks
return Column(func(_to_java_column(timestamp_column)))
def unix_epoch_to_dotnet_ticks(unix_column: Union[str, Column]) -> Column:
"""
Convert a Unix epoch timestamp to a .Net `DateTime.Ticks` timestamp.
The input column must represent a numerical unix epoch timestamp, e.g. long, double, string or decimal.
The input must not be of TimestampType, as that may be interpreted incorrectly.
Use `timestampToDotNetTicks` for TimestampType columns instead.
Example:
{{{
df.select(col("unix"), unixEpochToDotNetTicks("unix").alias("ticks")).show(false)
}}}
+-----------------------------+------------------+
|unix |ticks |
+-----------------------------+------------------+
|1679944574.895931234000000000|638155413748959312|
+-----------------------------+------------------+
https://learn.microsoft.com/de-de/dotnet/api/system.datetime.ticks
:param unix_column: column with a unix epoch value
:return: tick value column
"""
if not _is_column_or_str(unix_column):
raise ValueError(f"Given column must be a column name (str) or column instance (Column): {type(unix_column)}")
jvm = _get_jvm(SparkContext._active_spark_context)
func = jvm.uk.co.gresearch.spark.__getattr__("package$").__getattr__("MODULE$").unixEpochToDotNetTicks
return Column(func(_to_java_column(unix_column)))
def unix_epoch_nanos_to_dotnet_ticks(unix_column: Union[str, Column]) -> Column:
"""
Convert a Unix epoch nanosecond timestamp to a .Net `DateTime.Ticks` timestamp.
The .Net ticks timestamp does not support the two lowest nanosecond digits,
so only a 1/10 of a microsecond is the smallest resolution.
The input column must represent a numerical unix epoch nanoseconds timestamp,
e.g. long, double, string or decimal.
Example:
{{{
df.select(col("unix_nanos"), unixEpochNanosToDotNetTicks("unix_nanos").alias("ticks")).show(false)
}}}
+-------------------+------------------+
|unix_nanos |ticks |
+-------------------+------------------+
|1679944574895931234|638155413748959312|
+-------------------+------------------+
https://learn.microsoft.com/de-de/dotnet/api/system.datetime.ticks
:param unix_column: column with a unix epoch value
:return: tick value column
"""
if not _is_column_or_str(unix_column):
raise ValueError(f"Given column must be a column name (str) or column instance (Column): {type(unix_column)}")
jvm = _get_jvm(SparkContext._active_spark_context)
func = jvm.uk.co.gresearch.spark.__getattr__("package$").__getattr__("MODULE$").unixEpochNanosToDotNetTicks
return Column(func(_to_java_column(unix_column)))
def count_null(e: "ColumnOrName") -> Column:
"""
Aggregate function: returns the number of items in a group that are null.
Parameters
----------
e : :class:`~pyspark.sql.Column` or str target column to compute on.
Returns
-------
:class:`~pyspark.sql.Column`
column for computed results.
"""
if isinstance(e, str):
e = col(e)
if not _is_column(e):
raise ValueError(f"Given column must be a column name (str) or column instance (Column): {type(e)}")
return count(when(e.isNull(), lit(1)))
def histogram(self: DataFrame,
thresholds: List[Union[int, float]],
value_column: str,
*aggregate_columns: str) -> DataFrame:
if len(thresholds) == 0:
t = 'Int'
else:
t = type(thresholds[0])
if t == int:
t = 'Int'
elif t == float:
t = 'Double'
else:
raise ValueError('thresholds must be ints or floats: {}'.format(t))
jvm = _get_jvm(self)
col = jvm.org.apache.spark.sql.functions.col
value_column = col(value_column)
aggregate_columns = [col(column) for column in aggregate_columns]
hist = jvm.uk.co.gresearch.spark.Histogram
jdf = hist.of(self._jdf, _to_seq(jvm, thresholds), value_column, _to_seq(jvm, aggregate_columns))
return DataFrame(jdf, self.session_or_ctx())
DataFrame.histogram = histogram
if has_connect:
ConnectDataFrame.histogram = histogram
class UnpersistHandle:
def __init__(self, handle):
self._handle = handle
def __call__(self, blocking: Optional[bool] = None):
if self._handle is not None:
if blocking is None:
self._handle.apply()
else:
self._handle.apply(blocking)
def unpersist_handle(self: SparkSession) -> UnpersistHandle:
jvm = _get_jvm(self)
handle = jvm.uk.co.gresearch.spark.UnpersistHandle()
return UnpersistHandle(handle)
SparkSession.unpersist_handle = unpersist_handle
def _get_sort_cols(df: DataFrame, order: Union[str, Column, List[Union[str, Column]]], ascending: Union[bool, List[bool]]):
if __version__.startswith('3.'):
# pyspark<4
return df._sort_cols([order], {'ascending': ascending})
# pyspark>=4
_cols = df._preapare_cols_for_sort(col, [order], {"ascending": ascending})
return df._jseq(_cols, _to_java_column)
def with_row_numbers(self: DataFrame,
storage_level: StorageLevel = StorageLevel.MEMORY_AND_DISK,
unpersist_handle: Optional[UnpersistHandle] = None,
row_number_column_name: str = "row_number",
order: Union[str, Column, List[Union[str, Column]]] = [],
ascending: Union[bool, List[bool]] = True) -> DataFrame:
jvm = _get_jvm(self)
jsl = self._sc._getJavaStorageLevel(storage_level)
juho = jvm.uk.co.gresearch.spark.UnpersistHandle
juh = unpersist_handle._handle if unpersist_handle else juho.Noop()
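# convert the order columns into a Scala Seq; an empty order list means: keep the current order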
jcols = _get_sort_cols(self, order, ascending) if not isinstance(order, list) or order else jvm.PythonUtils.toSeq([])
row_numbers = jvm.uk.co.gresearch.spark.RowNumbers
jdf = row_numbers \
.withRowNumberColumnName(row_number_column_name) \
.withStorageLevel(jsl) \
.withUnpersistHandle(juh) \
.withOrderColumns(jcols) \
.of(self._jdf)
return DataFrame(jdf, self.session_or_ctx())
DataFrame.with_row_numbers = with_row_numbers
if has_connect:
ConnectDataFrame.with_row_numbers = with_row_numbers
def session(self: DataFrame) -> SparkSession:
return self.sparkSession if hasattr(self, 'sparkSession') else self.sql_ctx.sparkSession
def session_or_ctx(self: DataFrame) -> Union[SparkSession, SQLContext]:
return self.sparkSession if hasattr(self, 'sparkSession') else self.sql_ctx
DataFrame.session = session
DataFrame.session_or_ctx = session_or_ctx
if has_connect:
ConnectDataFrame.session = session
ConnectDataFrame.session_or_ctx = session_or_ctx
def set_description(description: Optional[str], if_not_set: bool = False):
if description is not None:
assert isinstance(description, str), description
assert isinstance(if_not_set, bool), if_not_set
context = SparkContext._active_spark_context
jvm = _get_jvm(context)
spark_package = jvm.uk.co.gresearch.spark.__getattr__("package$").__getattr__("MODULE$")
return spark_package.setJobDescription(description, if_not_set, context._jsc.sc())
@contextmanager
def job_description(description: str, if_not_set: bool = False):
"""
Adds a job description to all Spark jobs started within this context.
The current Job description is restored after leaving the context.
Usage example:
>>> from gresearch.spark import job_description
>>>
>>> with job_description("parquet file"):
... df = spark.read.parquet("data.parquet")
... count = df.count()
With ``if_not_set = True``, the description is only set if no job description is set yet.
Any modification to the job description within the context is reverted on exit,
even if `if_not_set = True`.
:param description: job description
:param if_not_set: job description is only set if no description is set yet
"""
earlier = set_description(description, if_not_set)
try:
yield
finally:
set_description(earlier)
def append_description(extra_description: str, separator: str = " - "):
assert isinstance(extra_description, str), extra_description
assert isinstance(separator, str), separator
context = SparkContext._active_spark_context
jvm = _get_jvm(context)
spark_package = jvm.uk.co.gresearch.spark.__getattr__("package$").__getattr__("MODULE$")
return spark_package.appendJobDescription(extra_description, separator, context._jsc.sc())
@contextmanager
def append_job_description(extra_description: str, separator: str = " - "):
"""
Appends a job description to all Spark jobs started within this context.
The current Job description is extended by the separator and the extra description
on entering the context, and restored after leaving the context.
Usage example:
>>> from gresearch.spark import append_job_description
>>>
>>> with append_job_description("parquet file"):
... df = spark.read.parquet("data.parquet")
... with append_job_description("count"):
... count = df.count()
Any modification to the job description within the context is reverted on exit.
:param extra_description: job description to be appended
:param separator: separator used when appending description
"""
earlier = append_description(extra_description, separator)
try:
yield
finally:
set_description(earlier)
def create_temporary_dir(spark: Union[SparkSession, SparkContext], prefix: str) -> str:
"""
Create a temporary directory in a location (driver temp dir) that will be deleted on Spark application shutdown.
:param spark: spark session or context
:param prefix: prefix string of temporary directory name
:return: absolute path of temporary directory
"""
jvm = _get_jvm(spark)
root_dir = jvm.org.apache.spark.SparkFiles.getRootDirectory()
return tempfile.mkdtemp(prefix=prefix, dir=root_dir)
SparkSession.create_temporary_dir = create_temporary_dir
SparkContext.create_temporary_dir = create_temporary_dir
if has_connect:
ConnectSparkSession.create_temporary_dir = create_temporary_dir
def install_pip_package(spark: Union[SparkSession, SparkContext], *package_or_pip_option: str) -> None:
if __version__.startswith('2.') or __version__.startswith('3.0.'):
raise NotImplementedError(f'Not supported for PySpark {__version__}')
for option in package_or_pip_option:
assert isinstance(option, str), option
# just here to assert JVM is accessible
_get_jvm(spark)
if isinstance(spark, SparkSession):
spark = spark.sparkContext
# create temporary directory for packages, inside a directory which will be deleted on spark application shutdown
id = f"spark-extension-pip-pkgs-{time.time()}"
dir = spark.create_temporary_dir(f"{id}-")
# install packages via pip install
# it is best to run pip as a separate process and not calling into module pip
# https://pip.pypa.io/en/stable/user_guide/#using-pip-from-your-program
subprocess.check_call([sys.executable, '-m', 'pip', "install"] + list(package_or_pip_option) + ["--target", dir])
# zip packages and remove directory
zip = shutil.make_archive(dir, "zip", dir)
shutil.rmtree(dir)
# register zip file as archive, and add as python source
# once support for Spark 3.0 is dropped, replace with spark.addArchive()
spark._jsc.sc().addArchive(zip + "#" + id)
spark._python_includes.append(id)
sys.path.insert(1, os.path.join(SparkFiles.getRootDirectory(), id))
SparkSession.install_pip_package = install_pip_package
SparkContext.install_pip_package = install_pip_package
if has_connect:
ConnectSparkSession.install_pip_package = install_pip_package
def install_poetry_project(spark: Union[SparkSession, SparkContext],
*project: str,
poetry_python: Optional[str] = None,
pip_args: Optional[List[str]] = None) -> None:
import logging
logger = logging.getLogger()
# spark.install_pip_dependency has this limitation, and it is used by this method
# and we want to fail quickly here
if __version__.startswith('2.') or __version__.startswith('3.0.'):
raise NotImplementedError(f'Not supported for PySpark {__version__}')
for p in project:
assert isinstance(p, str), p
if poetry_python is not None:
assert isinstance(poetry_python, str), poetry_python
if pip_args is not None:
for pa in pip_args:
assert isinstance(pa, str), pa
# just here to assert JVM is accessible
_get_jvm(spark)
if isinstance(spark, SparkSession):
spark = spark.sparkContext
if poetry_python is None:
poetry_python = sys.executable
if pip_args is None:
pip_args = []
def check_and_log_poetry(proc: subprocess.CompletedProcess) -> List[str]:
stdout = proc.stdout.decode('utf-8').splitlines(keepends=False)
for line in stdout:
logger.info(f"poetry: {line}")
stderr = proc.stderr.decode('utf-8').splitlines(keepends=False)
for line in stderr:
logger.error(f"poetry: {line}")
if proc.returncode != 0:
raise RuntimeError(f'Poetry process terminated with exit code {proc.returncode}')
return stdout
def build_wheel(project: Path) -> Path:
logger.info(f"Running poetry using {poetry_python}")
# make sure the virtual env for this project exists, otherwise we won't get to see the build whl file in stdout
proc = subprocess.run([
poetry_python, '-m', 'poetry',
'env', 'use',
'--directory', str(project.absolute()),
sys.executable
], capture_output=True)
check_and_log_poetry(proc)
# build the whl file
proc = subprocess.run([
poetry_python, '-m', 'poetry',
'build',
'--verbose',
'--no-interaction',
'--format', 'wheel',
'--directory', str(project.absolute())
], capture_output=True)
stdout = check_and_log_poetry(proc)
# first matching line is taken to extract whl file name
whl_pattern = "^ - Built (.*.whl)$"
for line in stdout:
if match := re.match(whl_pattern, line):
return project.joinpath('dist', match.group(1))
raise RuntimeError(f'Could not find wheel file name in poetry output, was looking for "{whl_pattern}"')
wheels = [build_wheel(Path(path)) for path in project]
# install wheels via pip
spark.install_pip_package(*[str(whl.absolute()) for whl in wheels] + pip_args)
SparkSession.install_poetry_project = install_poetry_project
SparkContext.install_poetry_project = install_poetry_project
if has_connect:
ConnectSparkSession.install_poetry_project = install_poetry_project
================================================
FILE: python/gresearch/spark/diff/__init__.py
================================================
# Copyright 2020 G-Research
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import dataclasses
from dataclasses import dataclass
from enum import Enum
from functools import reduce
from typing import Optional, Dict, Mapping, Any, Callable, List, Tuple, Union, Iterable, overload
from py4j.java_gateway import JavaObject, JVMView
from pyspark.sql import DataFrame, Column
from pyspark.sql.functions import col, lit, when, concat, coalesce, array, struct
from pyspark.sql.types import DataType, StructField, ArrayType
from gresearch.spark import _get_jvm, _to_seq, _to_map, backticks, distinct_prefix_for, \
handle_configured_case_sensitivity, list_contains_case_sensitivity, list_filter_case_sensitivity, list_diff_case_sensitivity, \
has_connect, _is_dataframe
from gresearch.spark.diff.comparator import DiffComparator, DiffComparators, DefaultDiffComparator
try:
# There is a chance users run the Python code shipped inside the JVM (Maven) package with Spark
# without ever pip-installing the whl package, and would then lack dependencies like this one
#from typing_extensions import deprecated
raise ImportError()
except ImportError:
from typing import TypeVar
_T = TypeVar("_T")
class deprecated:
def __init__(self, msg: str) -> None:
self.msg = msg
def __call__(self, func: _T) -> _T:
import warnings
def deprecated_func(*args, **kwargs):
warnings.warn(self.msg, DeprecationWarning, stacklevel=2)
return func(*args, **kwargs)
return deprecated_func
class DiffMode(Enum):
ColumnByColumn = "ColumnByColumn"
SideBySide = "SideBySide"
LeftSide = "LeftSide"
RightSide = "RightSide"
# should be in sync with default defined in Java
Default = ColumnByColumn
def _to_java(self, jvm: JVMView) -> JavaObject:
return jvm.uk.co.gresearch.spark.diff.DiffMode.withNameOption(self.name).get()
@dataclass(frozen=True)
class DiffOptions:
"""
Configuration class for diffing Datasets.
:param diff_column: name of the diff column
:type diff_column: str
:param left_column_prefix: prefix of columns from the left Dataset
:type left_column_prefix: str
:param right_column_prefix: prefix of columns from the right Dataset
:type right_column_prefix: str
:param insert_diff_value: value in diff column for inserted rows
:type insert_diff_value: str
:param change_diff_value: value in diff column for changed rows
:type change_diff_value: str
:param delete_diff_value: value in diff column for deleted rows
:type delete_diff_value: str
:param nochange_diff_value: value in diff column for un-changed rows
:type nochange_diff_value: str
:param change_column: name of change column
:type change_column: str
:param diff_mode: diff mode
:type diff_mode: DiffMode
:param sparse_mode: sparse mode
:type sparse_mode: bool
"""
diff_column: str = 'diff'
left_column_prefix: str = 'left'
right_column_prefix: str = 'right'
insert_diff_value: str = 'I'
change_diff_value: str = 'C'
delete_diff_value: str = 'D'
nochange_diff_value: str = 'N'
change_column: Optional[str] = None
diff_mode: DiffMode = DiffMode.Default
sparse_mode: bool = False
default_comparator: DiffComparator = DefaultDiffComparator()
data_type_comparators: Dict[DataType, DiffComparator] = dataclasses.field(default_factory=lambda: dict())
column_name_comparators: Dict[str, DiffComparator] = dataclasses.field(default_factory=lambda: dict())
def with_diff_column(self, diff_column: str) -> 'DiffOptions':
"""
Fluent method to change the diff column name.
Returns a new immutable DiffOptions instance with the new diff column name.
:param diff_column: new diff column name
:type diff_column: str
:return: new immutable DiffOptions instance
:rtype: DiffOptions
"""
assert isinstance(diff_column, str), diff_column
return dataclasses.replace(self, diff_column=diff_column)
def with_left_column_prefix(self, left_column_prefix: str) -> 'DiffOptions':
"""
Fluent method to change the prefix of columns from the left Dataset.
Returns a new immutable DiffOptions instance with the new column prefix.
:param left_column_prefix: new column prefix
:type left_column_prefix: str
:return: new immutable DiffOptions instance
:rtype: DiffOptions
"""
assert isinstance(left_column_prefix, str), left_column_prefix
return dataclasses.replace(self, left_column_prefix=left_column_prefix)
def with_right_column_prefix(self, right_column_prefix: str) -> 'DiffOptions':
"""
Fluent method to change the prefix of columns from the right Dataset.
Returns a new immutable DiffOptions instance with the new column prefix.
:param right_column_prefix: new column prefix
:type right_column_prefix: str
:return: new immutable DiffOptions instance
:rtype: DiffOptions
"""
assert isinstance(right_column_prefix, str), right_column_prefix
return dataclasses.replace(self, right_column_prefix=right_column_prefix)
def with_insert_diff_value(self, insert_diff_value: str) -> 'DiffOptions':
"""
Fluent method to change the value of inserted rows in the diff column.
Returns a new immutable DiffOptions instance with the new diff value.
:param insert_diff_value: new diff value
:type insert_diff_value: str
:return: new immutable DiffOptions instance
:rtype: DiffOptions
"""
assert isinstance(insert_diff_value, str), insert_diff_value
return dataclasses.replace(self, insert_diff_value=insert_diff_value)
def with_change_diff_value(self, change_diff_value: str) -> 'DiffOptions':
"""
Fluent method to change the value of changed rows in the diff column.
Returns a new immutable DiffOptions instance with the new diff value.
:param change_diff_value: new diff value
:type change_diff_value: str
:return: new immutable DiffOptions instance
:rtype: DiffOptions
"""
assert isinstance(change_diff_value, str), change_diff_value
return dataclasses.replace(self, change_diff_value=change_diff_value)
def with_delete_diff_value(self, delete_diff_value: str) -> 'DiffOptions':
"""
Fluent method to change the value of deleted rows in the diff column.
Returns a new immutable DiffOptions instance with the new diff value.
:param delete_diff_value: new diff value
:type delete_diff_value: str
:return: new immutable DiffOptions instance
:rtype: DiffOptions
"""
assert isinstance(delete_diff_value, str), delete_diff_value
return dataclasses.replace(self, delete_diff_value=delete_diff_value)
def with_nochange_diff_value(self, nochange_diff_value: str) -> 'DiffOptions':
"""
Fluent method to change the value of un-changed rows in the diff column.
Returns a new immutable DiffOptions instance with the new diff value.
:param nochange_diff_value: new diff value
:type nochange_diff_value: str
:return: new immutable DiffOptions instance
:rtype: DiffOptions
"""
assert isinstance(nochange_diff_value, str), nochange_diff_value
return dataclasses.replace(self, nochange_diff_value=nochange_diff_value)
def with_change_column(self, change_column: str) -> 'DiffOptions':
"""
Fluent method to change the change column name.
Returns a new immutable DiffOptions instance with the new change column name.
:param change_column: new change column name
:type change_column: str
:return: new immutable DiffOptions instance
:rtype: DiffOptions
"""
assert isinstance(change_column, str), change_column
return dataclasses.replace(self, change_column=change_column)
def without_change_column(self) -> 'DiffOptions':
"""
Fluent method to remove change column.
Returns a new immutable DiffOptions instance without a change column.
:return: new immutable DiffOptions instance
:rtype: DiffOptions
"""
return dataclasses.replace(self, change_column=None)
def with_diff_mode(self, diff_mode: DiffMode) -> 'DiffOptions':
"""
Fluent method to change the diff mode.
Returns a new immutable DiffOptions instance with the new diff mode.
:param diff_mode: new diff mode
:type diff_mode: DiffMode
:return: new immutable DiffOptions instance
:rtype: DiffOptions
"""
assert isinstance(diff_mode, DiffMode), diff_mode
return dataclasses.replace(self, diff_mode=diff_mode)
def with_sparse_mode(self, sparse_mode: bool) -> 'DiffOptions':
"""
Fluent method to change the sparse mode.
Returns a new immutable DiffOptions instance with the new sparse mode.
:param sparse_mode: new sparse mode
:type sparse_mode: bool
:return: new immutable DiffOptions instance
:rtype: DiffOptions
"""
assert isinstance(sparse_mode, bool), sparse_mode
return dataclasses.replace(self, sparse_mode=sparse_mode)
def with_default_comparator(self, comparator: DiffComparator) -> 'DiffOptions':
assert isinstance(comparator, DiffComparator), comparator
return dataclasses.replace(self, default_comparator=comparator)
def with_data_type_comparator(self, comparator: DiffComparator, *data_type: DataType) -> 'DiffOptions':
assert isinstance(comparator, DiffComparator), comparator
for dt in data_type:
assert isinstance(dt, DataType), dt
existing_data_types = {dt.simpleString() for dt in data_type if dt in self.data_type_comparators.keys()}
if existing_data_types:
existing_data_types = sorted(list(existing_data_types))
raise ValueError(f'A comparator for data type{"s" if len(existing_data_types) > 1 else ""} '
f'{", ".join(existing_data_types)} exists already.')
data_type_comparators = self.data_type_comparators.copy()
data_type_comparators.update({dt: comparator for dt in data_type})
return dataclasses.replace(self, data_type_comparators=data_type_comparators)
def with_column_name_comparator(self, comparator: DiffComparator, *column_name: str) -> 'DiffOptions':
assert isinstance(comparator, DiffComparator), comparator
for cn in column_name:
assert isinstance(cn, str), cn
existing_column_names = {cn for cn in column_name if cn in self.column_name_comparators.keys()}
if existing_column_names:
existing_column_names = sorted(list(existing_column_names))
raise ValueError(f'A comparator for column name{"s" if len(existing_column_names) > 1 else ""} '
f'{", ".join(existing_column_names)} exists already.')
column_name_comparators = self.column_name_comparators.copy()
column_name_comparators.update({cn: comparator for cn in column_name})
return dataclasses.replace(self, column_name_comparators=column_name_comparators)
def comparator_for(self, column: StructField) -> DiffComparator:
assert isinstance(column, StructField), column
cmp = self.column_name_comparators.get(column.name)
if cmp is None:
cmp = self.data_type_comparators.get(column.dataType)
if cmp is None:
cmp = self.default_comparator
return cmp
class Differ:
"""
Differ class to diff two Datasets. See Differ.of(…) for details.
:param options: options for the diffing process
:type options: DiffOptions
"""
def __init__(self, options: DiffOptions = None):
self._options = options or DiffOptions()
@overload
def diff(self, left: DataFrame, right: DataFrame, *id_columns: str) -> DataFrame: ...
@overload
def diff(self, left: DataFrame, right: DataFrame, id_columns: Iterable[str], ignore_columns: Iterable[str]) -> DataFrame: ...
def diff(self, left: DataFrame, right: DataFrame, *id_or_ignore_columns: Union[str, Iterable[str]]) -> DataFrame:
"""
Returns a new DataFrame that contains the differences between the two DataFrames.
Both DataFrames must contain the same set of column names and data types.
The order of columns in the two DataFrames is not important as columns are compared based on the
name, not the position.
Optional id columns are used to uniquely identify rows to compare. If values in any non-id
column are differing between the two DataFrames, then that row is marked as `"C"`hange
and `"N"`o-change otherwise. Rows of the right DataFrame, that do not exist in the left DataFrame
(w.r.t. the values in the id columns) are marked as `"I"`nsert. And rows of the left DataFrame,
that do not exist in the right DataFrame are marked as `"D"`elete.
If no id columns are given, all columns are considered id columns. Then, no `"C"`hange rows
will appear, as all changes will exist as respective `"D"`elete and `"I"`nsert.
Values in optional ignore columns are not compared but included in the output DataFrame.
The returned DataFrame has the `diff` column as the first column. This holds the `"N"`, `"C"`,
`"I"` or `"D"` strings. The id columns follow, then the non-id columns (all remaining columns).
.. code-block:: python
df1 = spark.createDataFrame([(1, "one"), (2, "two"), (3, "three")], ["id", "value"])
df2 = spark.createDataFrame([(1, "one"), (2, "Two"), (4, "four")], ["id", "value"])
differ.diff(df1, df2).show()
# output:
# +----+---+-----+
# |diff| id|value|
# +----+---+-----+
# | N| 1| one|
# | D| 2| two|
# | I| 2| Two|
# | D| 3|three|
# | I| 4| four|
# +----+---+-----+
differ.diff(df1, df2, "id").show()
# output:
# +----+---+----------+-----------+
# |diff| id|left_value|right_value|
# +----+---+----------+-----------+
# | N| 1| one| one|
# | C| 2| two| Two|
# | D| 3| three| null|
# | I| 4| null| four|
# +----+---+----------+-----------+
The id columns appear in the order given to the method. If no id columns are given, then all
columns of the left DataFrame are id columns and appear in the same order. The remaining non-id
columns are in the order of the left DataFrame.
:param left: left DataFrame
:type left: DataFrame
:param right: right DataFrame
:type right: DataFrame
:param id_or_ignore_columns: either id column names or two lists of column names,
first the id column names, second the ignore column names
:type id_or_ignore_columns: str | Iterable[str]
:return: the diff DataFrame
:rtype: DataFrame
"""
assert _is_dataframe(left), left
assert _is_dataframe(right), right
assert isinstance(id_or_ignore_columns, (str, Iterable)), id_or_ignore_columns
if len(id_or_ignore_columns) == 2 and all(isinstance(lst, Iterable) and not isinstance(lst, str) for lst in id_or_ignore_columns):
id_columns, ignore_columns = id_or_ignore_columns
if any(not isinstance(id, str) for id in id_columns):
raise ValueError(f"The id_columns must all be strings: {', '.join(type(id).__name__ for id in id_columns)}")
if any(not isinstance(ignore, str) for ignore in ignore_columns):
raise ValueError(f"The ignore_columns must all be strings: {', '.join(type(ignore).__name__ for ignore in ignore_columns)}")
elif all(isinstance(lst, str) for lst in id_or_ignore_columns):
id_columns, ignore_columns = (id_or_ignore_columns, [])
else:
raise ValueError(f"The id_or_ignore_columns argument must either all be strings or exactly two iterables of strings: {', '.join(type(e).__name__ for e in id_or_ignore_columns)}")
return self._do_diff(left, right, id_columns, ignore_columns)
@staticmethod
def _columns_of_side(df: DataFrame, id_columns: List[str], side_prefix: str) -> List[Column]:
prefix = side_prefix + '_'
return [col(c) if c in id_columns else col(c).alias(c[len(prefix):])
        for c in df.columns if c in id_columns or c.startswith(prefix)]
@overload
def diffwith(self, left: DataFrame, right: DataFrame, *id_columns: str) -> DataFrame: ...
@overload
def diffwith(self, left: DataFrame, right: DataFrame, id_columns: Iterable[str], ignore_columns: Iterable[str]) -> DataFrame: ...
def diffwith(self, left: DataFrame, right: DataFrame, *id_or_ignore_columns: Union[str, Iterable[str]]) -> DataFrame:
"""
Returns a new DataFrame that contains the differences between the two DataFrames
as tuples of type `(String, Row, Row)`.
See `diff(left: DataFrame, right: DataFrame, *id_columns: str)`.
:param left: left DataFrame
:type left: DataFrame
:param right: right DataFrame
:type right: DataFrame
:param id_or_ignore_columns: either id column names or two lists of column names,
first the id column names, second the ignore column names
:type id_or_ignore_columns: str | Iterable[str]
:return: the diff DataFrame
:rtype: DataFrame
"""
assert _is_dataframe(left), left
assert _is_dataframe(right), right
assert isinstance(id_or_ignore_columns, (str, Iterable)), id_or_ignore_columns
if len(id_or_ignore_columns) == 2 and all(isinstance(lst, Iterable) and not isinstance(lst, str) for lst in id_or_ignore_columns):
id_columns, ignore_columns = id_or_ignore_columns
if any(not isinstance(id, str) for id in id_columns):
raise ValueError(f"The id_columns must all be strings: {', '.join(type(id).__name__ for id in id_columns)}")
if any(not isinstance(ignore, str) for ignore in ignore_columns):
raise ValueError(f"The ignore_columns must all be strings: {', '.join(type(ignore).__name__ for ignore in ignore_columns)}")
elif all(isinstance(lst, str) for lst in id_or_ignore_columns):
id_columns, ignore_columns = (id_or_ignore_columns, [])
else:
raise ValueError(f"The id_or_ignore_columns argument must either all be strings or exactly two iterables of strings: {', '.join(type(e).__name__ for e in id_or_ignore_columns)}")
diff = self._do_diff(left, right, id_columns, ignore_columns)
left_columns = self._columns_of_side(diff, id_columns, self._options.left_column_prefix)
right_columns = self._columns_of_side(diff, id_columns, self._options.right_column_prefix)
diff_column = col(self._options.diff_column)
left_struct = when(diff_column == self._options.insert_diff_value, lit(None)) \
.otherwise(struct(*left_columns)) \
.alias(self._options.left_column_prefix)
right_struct = when(diff_column == self._options.delete_diff_value, lit(None)) \
.otherwise(struct(*right_columns)) \
.alias(self._options.right_column_prefix)
return diff.select(diff_column, left_struct, right_struct)
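# A minimal usage sketch (hedged; df1/df2 as in the diff example above):
#
#   differ.diffwith(df1, df2, "id").collect()
#   # [Row(diff='N', left=Row(id=1, value='one'), right=Row(id=1, value='one')),
#   #  Row(diff='C', left=Row(id=2, value='two'), right=Row(id=2, value='Two')), ...]
#
# 'I' rows carry left=None and 'D' rows carry right=None, per the when() clauses above.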
def _check_schema(self, left: DataFrame, right: DataFrame, id_columns: List[str], ignore_columns: List[str], case_sensitive: bool):
def require(result: bool, message: str) -> None:
if not result:
raise ValueError(message)
require(
len(left.columns) == len(set(left.columns)) and len(right.columns) == len(set(right.columns)),
f"The datasets have duplicate columns.\n" +
f"Left column names: {', '.join(left.columns)}\n" +
f"Right column names: {', '.join(right.columns)}")
left_non_ignored = list_diff_case_sensitivity(left.columns, ignore_columns, case_sensitive)
right_non_ignored = list_diff_case_sensitivity(right.columns, ignore_columns, case_sensitive)
except_ignored_columns_msg = ' except ignored columns' if ignore_columns else ''
require(
len(left_non_ignored) == len(right_non_ignored),
"The number of columns doesn't match.\n" +
f"Left column names{except_ignored_columns_msg} ({len(left_non_ignored)}): {', '.join(left_non_ignored)}\n" +
f"Right column names{except_ignored_columns_msg} ({len(right_non_ignored)}): {', '.join(right_non_ignored)}"
)
require(len(left_non_ignored) > 0, f"The schema{except_ignored_columns_msg} must not be empty")
# column types must match but we ignore the nullability of columns
left_fields = {handle_configured_case_sensitivity(field.name, case_sensitive): field.dataType
for field in left.schema.fields
if not list_contains_case_sensitivity(ignore_columns, field.name, case_sensitive)}
right_fields = {handle_configured_case_sensitivity(field.name, case_sensitive): field.dataType
for field in right.schema.fields
if not list_contains_case_sensitivity(ignore_columns, field.name, case_sensitive)}
left_extra_schema = set(left_fields.items()) - set(right_fields.items())
right_extra_schema = set(right_fields.items()) - set(left_fields.items())
require(
len(left_extra_schema) == 0 and len(right_extra_schema) == 0,
"The datasets do not have the same schema.\n" +
f"Left extra columns: {', '.join([f'{f} ({t.typeName()})' for f, t in sorted(list(left_extra_schema))])}\n" +
f"Right extra columns: {', '.join([f'{f} ({t.typeName()})' for f, t in sorted(list(right_extra_schema))])}")
columns = left_non_ignored
pk_columns = id_columns or columns
non_pk_columns = list_diff_case_sensitivity(columns, pk_columns, case_sensitive)
missing_id_columns = list_diff_case_sensitivity(pk_columns, columns, case_sensitive)
require(
len(missing_id_columns) == 0,
f"Some id columns do not exist: {', '.join(missing_id_columns)} missing among {', '.join(columns)}"
)
missing_ignore_columns = list_diff_case_sensitivity(ignore_columns, left.columns + right.columns, case_sensitive)
require(
len(missing_ignore_columns) == 0,
f"Some ignore columns do not exist: {', '.join(missing_ignore_columns)} " +
f"missing among {', '.join(sorted(list(set(left_non_ignored + right_non_ignored))))}"
)
require(
not list_contains_case_sensitivity(pk_columns, self._options.diff_column, case_sensitive),
f"The id columns must not contain the diff column name '{self._options.diff_column}': {', '.join(pk_columns)}"
)
require(
self._options.change_column is None or not list_contains_case_sensitivity(pk_columns, self._options.change_column, case_sensitive),
f"The id columns must not contain the change column name '{self._options.change_column}': {', '.join(pk_columns)}"
)
diff_value_columns = self._get_diff_value_columns(pk_columns, non_pk_columns, left, right, ignore_columns, case_sensitive)
diff_value_columns = {n for n, t in diff_value_columns}
if self._options.diff_mode in [DiffMode.LeftSide, DiffMode.RightSide]:
require(
not list_contains_case_sensitivity(diff_value_columns, self._options.diff_column, case_sensitive),
f"The {'left' if self._options.diff_mode == DiffMode.LeftSide else 'right'} " +
f"non-id columns must not contain the diff column name '{self._options.diff_column}': " +
f"{', '.join(list_diff_case_sensitivity((left if self._options.diff_mode == DiffMode.LeftSide else right).columns, id_columns, case_sensitive))}"
)
require(
self._options.change_column is None or not list_contains_case_sensitivity(diff_value_columns, self._options.change_column, case_sensitive),
f"The {'left' if self._options.diff_mode == DiffMode.LeftSide else 'right'} " +
f"non-id columns must not contain the change column name '{self._options.change_column}': " +
f"{', '.join(list_diff_case_sensitivity((left if self._options.diff_mode == DiffMode.LeftSide else right).columns, id_columns, case_sensitive))}"
)
else:
require(
not list_contains_case_sensitivity(diff_value_columns, self._options.diff_column, case_sensitive),
f"The column prefixes '{self._options.left_column_prefix}' and '{self._options.right_column_prefix}', " +
f"together with these non-id columns must not produce the diff column name '{self._options.diff_column}': " +
f"{', '.join(non_pk_columns)}"
)
require(
self._options.change_column is None or not list_contains_case_sensitivity(diff_value_columns, self._options.change_column, case_sensitive),
f"The column prefixes '{self._options.left_column_prefix}' and '{self._options.right_column_prefix}', " +
f"together with these non-id columns must not produce the change column name '{self._options.change_column}': " +
f"{', '.join(non_pk_columns)}"
)
require(
all(not list_contains_case_sensitivity(pk_columns, c, case_sensitive) for c in diff_value_columns),
f"The column prefixes '{self._options.left_column_prefix}' and '{self._options.right_column_prefix}', " +
f"together with these non-id columns must not produce any id column name '{', '.join(pk_columns)}': " +
f"{', '.join(non_pk_columns)}"
)
def _get_change_column(self,
exists_column_name: str,
value_columns_with_comparator: List[Tuple[str, DiffComparator]],
left: DataFrame,
right: DataFrame) -> Optional[Column]:
if self._options.change_column is None:
return None
if not self._options.change_column:
return array().cast(ArrayType(StringType(), containsNull=False)).alias(self._options.change_column)
return when(left[exists_column_name].isNull() | right[exists_column_name].isNull(), lit(None)) \
.otherwise(
concat(*[when(cmp.equiv(left[c], right[c]), array()).otherwise(array(lit(c)))
for (c, cmp) in value_columns_with_comparator])) \
.alias(self._options.change_column)
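# Worked example of the change column semantics (mirroring the test expectations):
# for a 'C' row where only `val` changed the change column is ['val'], for an 'N'
# row it is [], and for 'I'/'D' rows (one side missing) it is null.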
def _do_diff(self, left: DataFrame, right: DataFrame, id_columns: List[str], ignore_columns: List[str]) -> DataFrame:
case_sensitive = left.session().conf.get("spark.sql.caseSensitive") == "true"
self._check_schema(left, right, id_columns, ignore_columns, case_sensitive)
columns = list_diff_case_sensitivity(left.columns, ignore_columns, case_sensitive)
pk_columns = id_columns or columns
value_columns = list_diff_case_sensitivity(columns, pk_columns, case_sensitive)
value_struct_fields = {f.name: f for f in left.schema.fields}
value_columns_with_comparator = [(c, self._options.comparator_for(value_struct_fields[c])) for c in value_columns]
exists_column_name = distinct_prefix_for(left.columns) + "exists"
left_with_exists = left.withColumn(exists_column_name, lit(1))
right_with_exists = right.withColumn(exists_column_name, lit(1))
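# the lit(1) marker columns make missing rows detectable after the full outer join
# below: a null marker means the row does not exist on that side, which a null
# value column alone could not distinguish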
join_condition = reduce(lambda l, r: l & r,
[left_with_exists[c].eqNullSafe(right_with_exists[c])
for c in pk_columns])
un_changed = reduce(lambda l, r: l & r,
[cmp.equiv(left_with_exists[c], right_with_exists[c])
for (c, cmp) in value_columns_with_comparator],
lit(True))
change_condition = ~un_changed
diff_action_column = \
when(left_with_exists[exists_column_name].isNull(), lit(self._options.insert_diff_value)) \
.when(right_with_exists[exists_column_name].isNull(), lit(self._options.delete_diff_value)) \
.when(change_condition, lit(self._options.change_diff_value)) \
.otherwise(lit(self._options.nochange_diff_value)) \
.alias(self._options.diff_column)
diff_columns = [c[1] for c in self._get_diff_columns(pk_columns, value_columns, left, right, ignore_columns, case_sensitive)]
# wrap the change column into a list of zero or one columns, so it can easily be
# concatenated below with the diff action column and the diff columns
change_column = self._get_change_column(exists_column_name, value_columns_with_comparator, left_with_exists, right_with_exists)
change_columns = [change_column] if change_column is not None else []
return left_with_exists \
.join(right_with_exists, join_condition, "fullouter") \
.select(*([diff_action_column] + change_columns + diff_columns))
def _get_diff_id_columns(self, pk_columns: List[str],
left: DataFrame,
right: DataFrame) -> List[Tuple[str, Column]]:
return [(c, coalesce(left[c], right[c]).alias(c)) for c in pk_columns]
def _get_diff_value_columns(self, pk_columns: List[str],
value_columns: List[str],
left: DataFrame,
right: DataFrame,
ignore_columns: List[str],
case_sensitive: bool) -> List[Tuple[str, Column]]:
left_value_columns = list_filter_case_sensitivity(left.columns, value_columns, case_sensitive)
right_value_columns = list_filter_case_sensitivity(right.columns, value_columns, case_sensitive)
left_non_pk_columns = list_diff_case_sensitivity(left.columns, pk_columns, case_sensitive)
right_non_pk_columns = list_diff_case_sensitivity(right.columns, pk_columns, case_sensitive)
left_ignored_columns = list_filter_case_sensitivity(left.columns, ignore_columns, case_sensitive)
right_ignored_columns = list_filter_case_sensitivity(right.columns, ignore_columns, case_sensitive)
left_values = {handle_configured_case_sensitivity(c, case_sensitive): (c, when(~(left[c].eqNullSafe(right[c])), left[c]) if self._options.sparse_mode else left[c]) for c in left_non_pk_columns}
right_values = {handle_configured_case_sensitivity(c, case_sensitive): (c, when(~(left[c].eqNullSafe(right[c])), right[c]) if self._options.sparse_mode else right[c]) for c in right_non_pk_columns}
def alias(prefix: Optional[str], values: Dict[str, Tuple[str, Column]]) -> Callable[[str], Tuple[str, Column]]:
def func(name: str) -> Tuple[str, Column]:
name, column = values[handle_configured_case_sensitivity(name, case_sensitive)]
alias = name if prefix is None else f'{prefix}_{name}'
return alias, column.alias(alias)
return func
def alias_left(name: str) -> Tuple[str, Column]:
return alias(self._options.left_column_prefix, left_values)(name)
def alias_right(name: str) -> Tuple[str, Column]:
return alias(self._options.right_column_prefix, right_values)(name)
prefixed_left_ignored_columns = [alias_left(c) for c in left_ignored_columns]
prefixed_right_ignored_columns = [alias_right(c) for c in right_ignored_columns]
if self._options.diff_mode == DiffMode.ColumnByColumn:
non_id_columns = \
[c for vc in value_columns for c in [alias_left(vc), alias_right(vc)]] + \
[c for ic in ignore_columns for c in (
([alias_left(ic)] if list_contains_case_sensitivity(left_ignored_columns, ic, case_sensitive) else []) +
([alias_right(ic)] if list_contains_case_sensitivity(right_ignored_columns, ic, case_sensitive) else [])
)]
elif self._options.diff_mode == DiffMode.SideBySide:
non_id_columns = \
[alias_left(c) for c in left_value_columns] + prefixed_left_ignored_columns + \
[alias_right(c) for c in right_value_columns] + prefixed_right_ignored_columns
elif self._options.diff_mode == DiffMode.LeftSide:
non_id_columns = \
[alias(None, left_values)(c) for c in value_columns] + \
[alias(None, left_values)(c) for c in left_ignored_columns]
elif self._options.diff_mode == DiffMode.RightSide:
non_id_columns = \
[alias(None, right_values)(c) for c in value_columns] + \
[alias(None, right_values)(c) for c in right_ignored_columns]
else:
raise RuntimeError(f'Unsupported diff mode: {self._options.diff_mode}')
return non_id_columns
def _get_diff_columns(self, pk_columns: List[str],
value_columns: List[str],
left: DataFrame,
right: DataFrame,
ignore_columns: List[str],
case_sensitive: bool) -> List[Tuple[str, Column]]:
return self._get_diff_id_columns(pk_columns, left, right) + \
self._get_diff_value_columns(pk_columns, value_columns, left, right, ignore_columns, case_sensitive)
@overload
def diff(self: DataFrame, other: DataFrame, *id_columns: str) -> DataFrame: ...
@overload
def diff(self: DataFrame, other: DataFrame, id_columns: Iterable[str], ignore_columns: Iterable[str]) -> DataFrame: ...
@overload
def diff(self: DataFrame, other: DataFrame, *id_or_ignore_columns: Union[str, Iterable[str]]) -> DataFrame: ...
@overload
def diff(self: DataFrame, other: DataFrame, options: DiffOptions, *id_columns: str) -> DataFrame: ...
@overload
def diff(self: DataFrame, other: DataFrame, options: DiffOptions, id_columns: Iterable[str], ignore_columns: Iterable[str]) -> DataFrame: ...
@overload
def diff(self: DataFrame, other: DataFrame, options: DiffOptions, *id_or_ignore_columns: Union[str, Iterable[str]]) -> DataFrame: ...
def diff(self: DataFrame, other: DataFrame, *options_or_id_or_ignore_columns: Union[DiffOptions, str, Iterable[str]]) -> DataFrame:
"""
Returns a new DataFrame that contains the differences between this and the other DataFrame.
Both DataFrames must contain the same set of column names and data types.
The order of columns in the two DataFrames is not important, as columns are compared
by name, not by position.
Optional options allow for customizing diffing behaviour and diff result schema.
Optional id columns are used to uniquely identify rows to compare. If values in any non-id
column differ between this and the other DataFrame, that row is marked as `"C"`hange,
otherwise as `"N"`o-change. Rows of the other DataFrame that do not exist in this DataFrame
(w.r.t. the values in the id columns) are marked as `"I"`nsert, and rows of this DataFrame
that do not exist in the other DataFrame are marked as `"D"`elete.
If no id columns are given, all columns are considered id columns. Then, no `"C"`hange rows
will appear, as all changes will exist as respective `"D"`elete and `"I"`nsert.
Values in optional ignore columns are not compared but included in the output DataFrame.
The returned DataFrame has the `diff` column as the first column. This holds the `"N"`, `"C"`,
`"I"` or `"D"` strings. The id columns follow, then the non-id columns (all remaining columns).
.. code-block:: python
df1 = spark.createDataFrame([(1, "one"), (2, "two"), (3, "three")], ["id", "value"])
df2 = spark.createDataFrame([(1, "one"), (2, "Two"), (4, "four")], ["id", "value"])
df1.diff(df2).show()
# output:
# +----+---+-----+
# |diff| id|value|
# +----+---+-----+
# | N| 1| one|
# | D| 2| two|
# | I| 2| Two|
# | D| 3|three|
# | I| 4| four|
# +----+---+-----+
df1.diff(df2, "id").show()
# output:
# +----+---+----------+-----------+
# |diff| id|left_value|right_value|
# +----+---+----------+-----------+
# | N| 1| one| one|
# | C| 2| two| Two|
# | D| 3| three| null|
# | I| 4| null| four|
# +----+---+----------+-----------+
The id columns are in order as given to the method. If no id columns are given then all
columns of this DataFrame are id columns and appear in the same order. The remaining non-id
columns are in the order of this DataFrame.
:param other: right DataFrame
:type other: DataFrame
:param options: optional diff options
:type options: DiffOptions
:param id_columns: id column names
:type id_columns: str
:param ignore_columns: optional ignore column names
:type ignore_columns: Iterable[str]
:param id_or_ignore_columns: either id column names or two lists of column names,
first the id column names, second the ignore column names
:type id_or_ignore_columns: str | Iterable[str]
:return: the diff DataFrame
:rtype: DataFrame
"""
if any(isinstance(i, DiffOptions) for i in options_or_id_or_ignore_columns):
options = options_or_id_or_ignore_columns[0]
if not isinstance(options, DiffOptions):
raise ValueError("Diff options must be given as second argument")
id_or_ignore_columns = options_or_id_or_ignore_columns[1:]
return Differ(options).diff(self, other, *id_or_ignore_columns)
id_or_ignore_columns = options_or_id_or_ignore_columns
return Differ().diff(self, other, *id_or_ignore_columns)
@overload
def diffwith(self: DataFrame, other: DataFrame, *id_columns: str) -> DataFrame: ...
@overload
def diffwith(self: DataFrame, other: DataFrame, id_columns: Iterable[str], ignore_columns: Iterable[str]) -> DataFrame: ...
@overload
def diffwith(self: DataFrame, other: DataFrame, *id_or_ignore_columns: Union[str, Iterable[str]]) -> DataFrame: ...
@overload
def diffwith(self: DataFrame, other: DataFrame, options: DiffOptions, *id_columns: str) -> DataFrame: ...
@overload
def diffwith(self: DataFrame, other: DataFrame, options: DiffOptions, id_columns: Iterable[str], ignore_columns: Iterable[str]) -> DataFrame: ...
@overload
def diffwith(self: DataFrame, other: DataFrame, options: DiffOptions, *id_or_ignore_columns: Union[str, Iterable[str]]) -> DataFrame: ...
def diffwith(self: DataFrame, other: DataFrame, *options_or_id_or_ignore_columns: Union[DiffOptions, str, Iterable[str]]) -> DataFrame:
"""
Returns a new DataFrame that contains the differences between the two DataFrames
as tuples of type `(String, Row, Row)`.
See `diff(other: DataFrame, *options_or_id_or_ignore_columns)`.
:param other: right DataFrame
:type other: DataFrame
:param options: diff options
:type options: DiffOptions
:param id_columns: id column names
:type id_columns: str
:param ignore_columns: optional ignore column names
:type ignore_columns: Iterable[str]
:param id_or_ignore_columns: either id column names or two lists of column names,
first the id column names, second the ignore column names
:type id_or_ignore_columns: str | Iterable[str]
:return: the diff DataFrame
:rtype: DataFrame
"""
if any(isinstance(i, DiffOptions) for i in options_or_id_or_ignore_columns):
options = options_or_id_or_ignore_columns[0]
if not isinstance(options, DiffOptions):
raise ValueError("Diff options must be given as second argument")
id_or_ignore_columns = options_or_id_or_ignore_columns[1:]
return Differ(options).diffwith(self, other, *id_or_ignore_columns)
id_or_ignore_columns = options_or_id_or_ignore_columns
return Differ().diffwith(self, other, *id_or_ignore_columns)
@overload
def diff_with_options(self: DataFrame, other: DataFrame, options: DiffOptions, *id_columns: str) -> DataFrame: ...
@overload
def diff_with_options(self: DataFrame, other: DataFrame, options: DiffOptions, id_columns: Iterable[str], ignore_columns: Iterable[str]) -> DataFrame: ...
@deprecated("Use diff with identical arguments instead")
def diff_with_options(self: DataFrame, other: DataFrame, options: DiffOptions, *id_or_ignore_columns: Union[str, Iterable[str]]) -> DataFrame:
"""
Returns a new DataFrame that contains the differences between this and the other DataFrame.
See `diff(other: DataFrame, *id_columns: str)`.
The schema of the returned DataFrame can be configured by the given `DiffOptions`.
:param other: right DataFrame
:type other: DataFrame
:param options: diff options
:type options: DiffOptions
:param id_or_ignore_columns: either id column names or two lists of column names,
first the id column names, second the ignore column names
:type id_or_ignore_columns: str | Iterable[str]
:return: the diff DataFrame
:rtype: DataFrame
"""
return Differ(options).diff(self, other, *id_or_ignore_columns)
@overload
def diffwith_with_options(self: DataFrame, other: DataFrame, options: DiffOptions, *id_columns: str) -> DataFrame: ...
@overload
def diffwith_with_options(self: DataFrame, other: DataFrame, options: DiffOptions, id_columns: Iterable[str], ignore_columns: Iterable[str]) -> DataFrame: ...
@deprecated("Use diffwith with identical arguments instead")
def diffwith_with_options(self: DataFrame, other: DataFrame, options: DiffOptions, *id_or_ignore_columns: Union[str, Iterable[str]]) -> DataFrame:
"""
Returns a new DataFrame that contains the differences between the two DataFrames
as tuples of type `(String, Row, Row)`.
See `diff(left: DataFrame, right: DataFrame, *id_columns: str)`.
The schema of the returned DataFrame can be configured by the given `DiffOptions`.
:param other: right DataFrame
:type other: DataFrame
:param options: diff options
:type options: DiffOptions
:param id_or_ignore_columns: either id column names or two lists of column names,
first the id column names, second the ignore column names
:type id_or_ignore_columns: str | Iterable[str]
:return: the diff DataFrame
:rtype: DataFrame
"""
return Differ(options).diffwith(self, other, *id_or_ignore_columns)
DataFrame.diff = diff
DataFrame.diffwith = diffwith
DataFrame.diff_with_options = diff_with_options
DataFrame.diffwith_with_options = diffwith_with_options
if has_connect:
from gresearch.spark import ConnectDataFrame
ConnectDataFrame.diff = diff
ConnectDataFrame.diffwith = diffwith
ConnectDataFrame.diff_with_options = diff_with_options
ConnectDataFrame.diffwith_with_options = diffwith_with_options
================================================
FILE: python/gresearch/spark/diff/comparator/__init__.py
================================================
# Copyright 2022 G-Research
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import abc
import dataclasses
from dataclasses import dataclass
from py4j.java_gateway import JVMView, JavaObject
from pyspark.sql import Column
from pyspark.sql.functions import abs, greatest, lit
from pyspark.sql.types import DataType
from gresearch.spark import _is_column
class DiffComparator(abc.ABC):
@abc.abstractmethod
def equiv(self, left: Column, right: Column) -> Column:
pass
class DiffComparators:
@staticmethod
def default() -> 'DefaultDiffComparator':
return DefaultDiffComparator()
@staticmethod
def nullSafeEqual() -> 'NullSafeEqualDiffComparator':
return NullSafeEqualDiffComparator()
@staticmethod
def epsilon(epsilon: float) -> 'EpsilonDiffComparator':
assert isinstance(epsilon, float), epsilon
return EpsilonDiffComparator(epsilon)
@staticmethod
def string(whitespace_agnostic: bool = True) -> 'StringDiffComparator':
assert isinstance(whitespace_agnostic, bool), whitespace_agnostic
return StringDiffComparator(whitespace_agnostic)
@staticmethod
def duration(duration: str) -> 'DurationDiffComparator':
assert isinstance(duration, str), duration
return DurationDiffComparator(duration)
@staticmethod
def map(key_type: DataType, value_type: DataType, key_order_sensitive: bool = False) -> 'MapDiffComparator':
assert isinstance(key_type, DataType), key_type
assert isinstance(value_type, DataType), value_type
assert isinstance(key_order_sensitive, bool), key_order_sensitive
return MapDiffComparator(key_type, value_type, key_order_sensitive)
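# Usage sketch (hedged; DiffOptions comes from gresearch.spark.diff):
#
#   cmp = DiffComparators.epsilon(0.01).as_absolute().as_inclusive()
#   options = DiffOptions().with_column_name_comparator(cmp, "val")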
class NullSafeEqualDiffComparator(DiffComparator):
def equiv(self, left: Column, right: Column) -> Column:
assert _is_column(left), left
assert _is_column(right), right
return left.eqNullSafe(right)
class DefaultDiffComparator(NullSafeEqualDiffComparator):
# for testing only
def _to_java(self, jvm: JVMView) -> JavaObject:
return jvm.uk.co.gresearch.spark.diff.DiffComparators.default()
@dataclass(frozen=True)
class EpsilonDiffComparator(DiffComparator):
epsilon: float
relative: bool = True
inclusive: bool = True
def as_relative(self) -> 'EpsilonDiffComparator':
return dataclasses.replace(self, relative=True)
def as_absolute(self) -> 'EpsilonDiffComparator':
return dataclasses.replace(self, relative=False)
def as_inclusive(self) -> 'EpsilonDiffComparator':
return dataclasses.replace(self, inclusive=True)
def as_exclusive(self) -> 'EpsilonDiffComparator':
return dataclasses.replace(self, inclusive=False)
def equiv(self, left: Column, right: Column) -> Column:
assert _is_column(left), left
assert _is_column(right), right
threshold = greatest(abs(left), abs(right)) * self.epsilon if self.relative else lit(self.epsilon)
def inclusive_epsilon(diff: Column) -> Column:
    return diff <= threshold
def exclusive_epsilon(diff: Column) -> Column:
    return diff < threshold
in_epsilon = inclusive_epsilon if self.inclusive else exclusive_epsilon
return (left.isNull() & right.isNull()) | (left.isNotNull() & right.isNotNull() & in_epsilon(abs(left - right)))
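# Worked example: with epsilon=0.1 (relative, inclusive) and values 10.0 and 10.5,
# threshold = max(|10.0|, |10.5|) * 0.1 = 1.05 and |10.0 - 10.5| = 0.5 <= 1.05, so
# the values are equivalent; as_absolute() would compare 0.5 against 0.1 instead.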
@dataclass(frozen=True)
class StringDiffComparator(DiffComparator):
whitespace_agnostic: bool
def equiv(self, left: Column, right: Column) -> Column:
assert _is_column(left), left
assert _is_column(right), right
return left.eqNullSafe(right)
@dataclass(frozen=True)
class DurationDiffComparator(DiffComparator):
duration: str
inclusive: bool = True
def as_inclusive(self) -> 'DurationDiffComparator':
return dataclasses.replace(self, inclusive=True)
def as_exclusive(self) -> 'DurationDiffComparator':
return dataclasses.replace(self, inclusive=False)
def equiv(self, left: Column, right: Column) -> Column:
assert _is_column(left), left
assert _is_column(right), right
return left.eqNullSafe(right)
@dataclass(frozen=True)
class MapDiffComparator(DiffComparator):
key_type: DataType
value_type: DataType
key_order_sensitive: bool
def equiv(self, left: Column, right: Column) -> Column:
assert _is_column(left), left
assert _is_column(right), right
return left.eqNullSafe(right)
================================================
FILE: python/gresearch/spark/parquet/__init__.py
================================================
# Copyright 2023 G-Research
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Optional
from py4j.java_gateway import JavaObject
from pyspark.sql import DataFrameReader, DataFrame
from gresearch.spark import _get_jvm, _to_seq
try:
from pyspark.sql.connect.readwriter import DataFrameReader as ConnectDataFrameReader
has_connect = True
except ImportError:
has_connect = False
def _jreader(reader: DataFrameReader) -> JavaObject:
jvm = _get_jvm(reader)
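# a Scala package object compiles to a `package$` class with a static `MODULE$`
# singleton; py4j reaches it via the explicit __getattr__ calls below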
return jvm.uk.co.gresearch.spark.parquet.__getattr__("package$").__getattr__("MODULE$").ExtendedDataFrameReader(reader._jreader)
def parquet_metadata(self: DataFrameReader, *paths: str, parallelism: Optional[int] = None) -> DataFrame:
"""
Read the metadata of Parquet files into a DataFrame.
The returned DataFrame has as many partitions as specified via `parallelism`.
If not specified, there are as many partitions as there are Parquet files,
at most `spark.sparkContext.defaultParallelism` partitions.
This provides the following per-file information:
- filename (string): The file name
- blocks (int): Number of blocks / RowGroups in the Parquet file
- compressedBytes (long): Number of compressed bytes of all blocks
- uncompressedBytes (long): Number of uncompressed bytes of all blocks
- rows (long): Number of rows in the file
- columns (int): Number of columns in the file
- values (long): Number of values in the file
- nulls (long): Number of null values in the file
- createdBy (string): The createdBy string of the Parquet file, e.g. library used to write the file
- schema (string): The schema
- encryption (string): The encryption
- keyValues (string-to-string map): Key-value data of the file
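Example (a minimal sketch; the path is a placeholder):

.. code-block:: python

    import gresearch.spark.parquet

    spark.read.parquet_metadata("/path/to/parquet").show()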
:param self: a Spark DataFrameReader
:param paths: one or more paths to Parquet files or directories
:param parallelism: number of partitions of returned DataFrame
:return: dataframe with Parquet metadata
"""
jvm = _get_jvm(self)
if parallelism is None:
jdf = _jreader(self).parquetMetadata(_to_seq(jvm, list(paths)))
else:
jdf = _jreader(self).parquetMetadata(parallelism, _to_seq(jvm, list(paths)))
return DataFrame(jdf, self._spark)
def parquet_schema(self: DataFrameReader, *paths: str, parallelism: Optional[int] = None) -> DataFrame:
"""
Read the schema of Parquet files into a DataFrame.
The returned DataFrame has as many partitions as specified via `parallelism`.
If not specified, there are as many partitions as there are Parquet files,
at most `spark.sparkContext.defaultParallelism` partitions.
This provides the following per-file information:
- filename (string): The Parquet file name
- columnName (string): The column name
- columnPath (string array): The column path
- repetition (string): The repetition
- type (string): The data type
- length (int): The length of the type
- originalType (string): The original type
- isPrimitive (boolean): True if the type is primitive
- primitiveType (string): The primitive type
- primitiveOrder (string): The order of the primitive type
- maxDefinitionLevel (int): The max definition level
- maxRepetitionLevel (int): The max repetition level
:param self: a Spark DataFrameReader
:param paths: one or more paths to Parquet files or directories
:param parallelism: number of partitions of returned DataFrame
:return: dataframe with Parquet metadata
"""
jvm = _get_jvm(self)
if parallelism is None:
jdf = _jreader(self).parquetSchema(_to_seq(jvm, list(paths)))
else:
jdf = _jreader(self).parquetSchema(parallelism, _to_seq(jvm, list(paths)))
return DataFrame(jdf, self._spark)
def parquet_blocks(self: DataFrameReader, *paths: str, parallelism: Optional[int] = None) -> DataFrame:
"""
Read the metadata of Parquet blocks into a DataFrame.
The returned DataFrame has as many partitions as specified via `parallelism`.
If not specified, there are as many partitions as there are Parquet files,
at most `spark.sparkContext.defaultParallelism` partitions.
This provides the following per-block information:
- filename (string): The file name
- block (int): Block / RowGroup number starting at 1
- blockStart (long): Start position of the block in the Parquet file
- compressedBytes (long): Number of compressed bytes in block
- uncompressedBytes (long): Number of uncompressed bytes in block
- rows (long): Number of rows in block
- columns (int): Number of columns in block
- values (long): Number of values in block
- nulls (long): Number of null values in block
:param self: a Spark DataFrameReader
:param paths: one or more paths to Parquet files or directories
:param parallelism: number of partitions of returned DataFrame
:return: dataframe with Parquet metadata
"""
jvm = _get_jvm(self)
if parallelism is None:
jdf = _jreader(self).parquetBlocks(_to_seq(jvm, list(paths)))
else:
jdf = _jreader(self).parquetBlocks(parallelism, _to_seq(jvm, list(paths)))
return DataFrame(jdf, self._spark)
def parquet_block_columns(self: DataFrameReader, *paths: str, parallelism: Optional[int] = None) -> DataFrame:
"""
Read the metadata of Parquet block columns into a DataFrame.
The returned DataFrame has as many partitions as specified via `parallelism`.
If not specified, there are as many partitions as there are Parquet files,
at most `spark.sparkContext.defaultParallelism` partitions.
This provides the following per-block-column information:
- filename (string): The file name
- block (int): Block / RowGroup number starting at 1
- column (array): Block / RowGroup column name
- codec (string): The codec used to compress the block column values
- type (string): The data type of the block column
- encodings (array): Encodings of the block column
- minValue (string): Minimum value of this column in this block
- maxValue (string): Maximum value of this column in this block
- columnStart (long): Start position of the block column in the Parquet file
- compressedBytes (long): Number of compressed bytes of this block column
- uncompressedBytes (long): Number of uncompressed bytes of this block column
- values (long): Number of values in this block column
- nulls (long): Number of null values in this block column
:param self: a Spark DataFrameReader
:param paths: one or more paths to Parquet files or directories
:param parallelism: number of partitions of returned DataFrame
:return: dataframe with Parquet metadata
"""
jvm = _get_jvm(self)
if parallelism is None:
jdf = _jreader(self).parquetBlockColumns(_to_seq(jvm, list(paths)))
else:
jdf = _jreader(self).parquetBlockColumns(parallelism, _to_seq(jvm, list(paths)))
return DataFrame(jdf, self._spark)
def parquet_partitions(self: DataFrameReader, *paths: str, parallelism: Optional[int] = None) -> DataFrame:
"""
Read the metadata of how Spark partitions Parquet files into a DataFrame.
The returned DataFrame has as many partitions as specified via `parallelism`.
If not specified, there are as many partitions as there are Parquet files,
at most `spark.sparkContext.defaultParallelism` partitions.
This provides the following per-partition information:
- partition (int): The Spark partition id
- partitionStart (long): The start position of the partition
- partitionEnd (long): The end position of the partition
- partitionLength (long): The length of the partition
- blocks (int): The number of Parquet blocks / RowGroups in this partition
- compressedBytes (long): The number of compressed bytes in this partition
- uncompressedBytes (long): The number of uncompressed bytes in this partition
- rows (long): The number of rows in this partition
- columns (int): The number of columns in this partition
- values (long): The number of values in this partition
- nulls (long): The number of null values in this partition
- filename (string): The Parquet file name
- fileLength (long): The length of the Parquet file
:param self: a Spark DataFrameReader
:param paths: one or more paths to Parquet files or directories
:param parallelism: number of partitions of returned DataFrame
:return: dataframe with Parquet metadata
"""
jvm = _get_jvm(self)
if parallelism is None:
jdf = _jreader(self).parquetPartitions(_to_seq(jvm, list(paths)))
else:
jdf = _jreader(self).parquetPartitions(parallelism, _to_seq(jvm, list(paths)))
return DataFrame(jdf, self._spark)
DataFrameReader.parquet_metadata = parquet_metadata
DataFrameReader.parquet_schema = parquet_schema
DataFrameReader.parquet_blocks = parquet_blocks
DataFrameReader.parquet_block_columns = parquet_block_columns
DataFrameReader.parquet_partitions = parquet_partitions
if has_connect:
ConnectDataFrameReader.parquet_metadata = parquet_metadata
ConnectDataFrameReader.parquet_schema = parquet_schema
ConnectDataFrameReader.parquet_blocks = parquet_blocks
ConnectDataFrameReader.parquet_block_columns = parquet_block_columns
ConnectDataFrameReader.parquet_partitions = parquet_partitions
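# Usage sketch for the reader methods registered above (paths are placeholders):
#
#   import gresearch.spark.parquet  # registers the methods on DataFrameReader
#   spark.read.parquet_blocks("/data/file.parquet").show()
#   spark.read.parquet_partitions("/data/file.parquet", parallelism=8).show()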
================================================
FILE: python/pyproject.toml
================================================
[build-system]
requires = ["setuptools"]
build-backend = "setuptools.build_meta"
================================================
FILE: python/pyspark/jars/.gitignore
================================================
# Ignore everything in this directory
*
# Except this file
!.gitignore
================================================
FILE: python/setup.py
================================================
#!/usr/bin/env python3
# Copyright 2023 G-Research
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import shutil
import subprocess
import sys
from pathlib import Path
from setuptools import setup
from setuptools.command.sdist import sdist
jar_version = '2.16.0-3.5-SNAPSHOT'
scala_version = '2.13.8'
scala_compat_version = '.'.join(scala_version.split('.')[:2])
spark_compat_version = jar_version.split('-')[1]
jar_file = f"spark-extension_{scala_compat_version}-{jar_version}.jar"
version = jar_version.replace('SNAPSHOT', 'dev0').replace('-', '.')
# read the contents of the README.md file
long_description = (Path(__file__).parent / "README.md").read_text()
class custom_sdist(sdist):
def make_distribution(self):
# build jar file via mvn if it does not exist
# then copy the jar file from target/ into python/pyspark/jars/
project_root = Path(__file__).parent.parent
jar_src_path = project_root / "target" / jar_file
jar_dst_path = project_root / "python" / "pyspark" / "jars" / jar_file
if not jar_src_path.exists():
# first set version for scala sources
set_version_command = ["./set-version.sh", f"{spark_compat_version}.0", scala_version]
# then package Scala sources
mvn_command = ["mvn", "--batch-mode", "package", "-Dspotless.check.skip", "-DskipTests", "-Dmaven.test.skip=true"]
print(' '.join(set_version_command))
try:
subprocess.check_call(set_version_command, cwd=str(project_root.absolute()))
except (OSError, subprocess.CalledProcessError) as e:
raise RuntimeError(f'setting versions failed: {e}')
print(f"building {jar_src_path}")
print(' '.join(mvn_command))
try:
subprocess.check_call(mvn_command, cwd=str(project_root.absolute()))
except (OSError, subprocess.CalledProcessError) as e:
raise RuntimeError(f'mvn command failed: {e}')
if not jar_src_path.exists():
print(f"Building jar file succeeded but file does still not exist: {jar_src_path}")
sys.exit(1)
print(f"copying {jar_src_path} -> {jar_dst_path}")
jar_dst_path.parent.mkdir(exist_ok=True)
shutil.copy2(jar_src_path, jar_dst_path)
self._add_data_files([("pyspark.jars", "pyspark/jars", ".", [jar_file])])
sdist.make_distribution(self)
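# The custom sdist command above runs when the source distribution is built, e.g.:
#   python3 setup.py sdist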
setup(
name="pyspark-extension",
version=version,
description="A library that provides useful extensions to Apache Spark.",
long_description=long_description,
long_description_content_type="text/markdown",
author="Enrico Minack",
author_email="github@enrico.minack.dev",
url="https://github.com/G-Research/spark-extension",
cmdclass={'sdist': custom_sdist},
install_requires=["typing_extensions"],
extras_require={
"test": [
"pandas>=1.0.5",
"py4j",
"pyarrow>=4.0.0",
f"pyspark~={spark_compat_version}.0",
"pytest",
"unittest-xml-reporting",
],
},
packages=[
"gresearch",
"gresearch.spark",
"gresearch.spark.diff",
"gresearch.spark.diff.comparator",
"gresearch.spark.parquet",
"pyspark.jars",
],
include_package_data=False,
package_data={
"pyspark.jars": [jar_file],
},
license="http://www.apache.org/licenses/LICENSE-2.0.html",
python_requires=">=3.7",
classifiers=[
"Development Status :: 5 - Production/Stable",
"License :: OSI Approved :: Apache Software License",
"Programming Language :: Python :: 3",
"Programming Language :: Python :: 3.7",
"Programming Language :: Python :: 3.8",
"Programming Language :: Python :: 3.9",
"Programming Language :: Python :: 3.10",
"Programming Language :: Python :: 3.11",
"Programming Language :: Python :: 3.12",
"Programming Language :: Python :: 3.13",
"Programming Language :: Python :: Implementation :: CPython",
"Programming Language :: Python :: Implementation :: PyPy",
"Typing :: Typed",
],
)
================================================
FILE: python/test/__init__.py
================================================
# Copyright 2020 G-Research
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
================================================
FILE: python/test/spark_common.py
================================================
# Copyright 2020 G-Research
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
import os
import sys
import unittest
from contextlib import contextmanager
from pathlib import Path
from pyspark import SparkConf
from pyspark.sql import SparkSession
logger = logging.getLogger()
logger.level = logging.INFO
@contextmanager
def spark_session():
session = SparkTest.get_spark_session()
try:
yield session
finally:
session.stop()
class SparkTest(unittest.TestCase):
@staticmethod
def main(file: str):
if len(sys.argv) == 2:
# location to store test results provided, this requires package unittest-xml-reporting
import xmlrunner
unittest.main(
module=f'test.{Path(file).name[:-3]}',
testRunner=xmlrunner.XMLTestRunner(output=sys.argv[1]),
argv=sys.argv[:1],
# these make sure that some options that are not applicable
# remain hidden from the help menu.
failfast=False, buffer=False, catchbreak=False
)
else:
unittest.main()
@staticmethod
def get_pom_path() -> str:
paths = ['.', '..', os.path.join('..', '..')]
for path in paths:
if os.path.exists(os.path.join(path, 'pom.xml')):
return path
raise RuntimeError('Could not find path to pom.xml, looked here: {}'.format(', '.join(paths)))
@staticmethod
def get_spark_config(path) -> SparkConf:
master = 'local[2]'
conf = SparkConf().setAppName('unit test').setMaster(master)
return conf.setAll([
('spark.ui.showConsoleProgress', 'false'),
('spark.test.home', os.environ.get('SPARK_HOME')),
('spark.locality.wait', '0'),
('spark.driver.extraClassPath', '{}'.format(':'.join([
os.path.join(os.getcwd(), path, 'target', 'classes'),
os.path.join(os.getcwd(), path, 'target', 'test-classes'),
]))),
])
@classmethod
def get_spark_session(cls) -> SparkSession:
builder = SparkSession.builder
if 'TEST_SPARK_CONNECT_SERVER' in os.environ:
builder.remote(os.environ['TEST_SPARK_CONNECT_SERVER'])
elif 'PYSPARK_GATEWAY_PORT' in os.environ:
logging.info('Running inside existing Spark environment')
else:
logging.info('Setting up Spark environment')
# setting conf spark.pyspark.python does not work
os.environ['PYSPARK_PYTHON'] = sys.executable
path = cls.get_pom_path()
conf = cls.get_spark_config(path)
builder.config(conf=conf)
return builder.getOrCreate()
spark: SparkSession = None
is_spark_connect: bool = 'TEST_SPARK_CONNECT_SERVER' in os.environ
@classmethod
def setUpClass(cls):
super(SparkTest, cls).setUpClass()
logging.info('launching Spark session')
cls.spark = cls.get_spark_session()
@classmethod
def tearDownClass(cls):
logging.info('stopping Spark session')
cls.spark.stop()
super(SparkTest, cls).tearDownClass()
@contextmanager
def sql_conf(self, pairs):
"""
Copied from pyspark/testing/sqlutils available from PySpark 3.5.0 and higher.
https://github.com/apache/spark/blob/v3.5.0/python/pyspark/testing/sqlutils.py#L171
http://www.apache.org/licenses/LICENSE-2.0
A convenient context manager to test some configuration specific logic. This sets
`value` to the configuration `key` and then restores it back when it exits.
"""
assert isinstance(pairs, dict), "pairs should be a dictionary."
assert hasattr(self, "spark"), "it should have 'spark' attribute, having a spark session."
keys = pairs.keys()
new_values = pairs.values()
old_values = [self.spark.conf.get(key, None) for key in keys]
for key, new_value in zip(keys, new_values):
self.spark.conf.set(key, new_value)
try:
yield
finally:
for key, old_value in zip(keys, old_values):
if old_value is None:
self.spark.conf.unset(key)
else:
self.spark.conf.set(key, old_value)
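# Usage sketch (mirrors how the diff tests use it):
#   with self.sql_conf({"spark.sql.caseSensitive": "true"}):
#       ...  # assertions that depend on case-sensitive column resolution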
================================================
FILE: python/test/test_diff.py
================================================
# Copyright 2020 G-Research
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import contextlib
import re
from py4j.java_gateway import JavaObject
from pyspark.sql import Row
from pyspark.sql.functions import col, when, abs
from pyspark.sql.types import IntegerType, LongType, StringType, DateType, StructField, StructType, FloatType, DoubleType
from unittest import skipIf
from gresearch.spark.diff import Differ, DiffOptions, DiffMode, DiffComparators, diffwith
from spark_common import SparkTest
class DiffTest(SparkTest):
expected_diff = None
@contextlib.contextmanager
def assert_requirement(self, error_message: str):
with self.assertRaises(ValueError) as e:
yield
self.assertEqual((error_message,), e.exception.args)
@classmethod
def setUpClass(cls):
super(DiffTest, cls).setUpClass()
value_row = Row('id', 'val', 'label')
cls.left_df = cls.spark.createDataFrame([
value_row(1, 1.0, 'one'),
value_row(2, 2.0, 'two'),
value_row(3, 3.0, 'three'),
value_row(4, None, None),
value_row(5, 5.0, 'five'),
value_row(7, 7.0, 'seven'),
])
cls.right_df = cls.spark.createDataFrame([
value_row(1, 1.1, 'one'),
value_row(2, 2.0, 'Two'),
value_row(3, 3.0, 'three'),
value_row(4, 4.0, 'four'),
value_row(5, None, None),
value_row(6, 6.0, 'six'),
])
diff_row = Row('diff', 'id', 'left_val', 'right_val', 'left_label', 'right_label')
cls.expected_diff = [
diff_row('C', 1, 1.0, 1.1, 'one', 'one'),
diff_row('C', 2, 2.0, 2.0, 'two', 'Two'),
diff_row('N', 3, 3.0, 3.0, 'three', 'three'),
diff_row('C', 4, None, 4.0, None, 'four'),
diff_row('C', 5, 5.0, None, 'five', None),
diff_row('I', 6, None, 6.0, None, 'six'),
diff_row('D', 7, 7.0, None, 'seven', None),
]
diff_change_row = Row('diff', 'change', 'id', 'left_val', 'right_val', 'left_label', 'right_label')
cls.expected_diff_change = [
diff_change_row('C', ['val'], 1, 1.0, 1.1, 'one', 'one'),
diff_change_row('C', ['label'], 2, 2.0, 2.0, 'two', 'Two'),
diff_change_row('N', [], 3, 3.0, 3.0, 'three', 'three'),
diff_change_row('C', ['val', 'label'], 4, None, 4.0, None, 'four'),
diff_change_row('C', ['val', 'label'], 5, 5.0, None, 'five', None),
diff_change_row('I', None, 6, None, 6.0, None, 'six'),
diff_change_row('D', None, 7, 7.0, None, 'seven', None),
]
cls.expected_diff_reversed = [
diff_row('C', 1, 1.1, 1.0, 'one', 'one'),
diff_row('C', 2, 2.0, 2.0, 'Two', 'two'),
diff_row('N', 3, 3.0, 3.0, 'three', 'three'),
diff_row('C', 4, 4.0, None, 'four', None),
diff_row('C', 5, None, 5.0, None, 'five'),
diff_row('D', 6, 6.0, None, 'six', None),
diff_row('I', 7, None, 7.0, None, 'seven'),
]
cls.expected_diff_ignored = [
diff_row('C', 1, 1.0, 1.1, 'one', 'one'),
diff_row('N', 2, 2.0, 2.0, 'two', 'Two'),
diff_row('N', 3, 3.0, 3.0, 'three', 'three'),
diff_row('C', 4, None, 4.0, None, 'four'),
diff_row('C', 5, 5.0, None, 'five', None),
diff_row('I', 6, None, 6.0, None, 'six'),
diff_row('D', 7, 7.0, None, 'seven', None),
]
diffwith_row = Row('diff', 'left', 'right')
cls.expected_diffwith = [
diffwith_row('C', value_row(1, 1.0, 'one'), value_row(1, 1.1, 'one')),
diffwith_row('C', value_row(2, 2.0, 'two'), value_row(2, 2.0, 'Two')),
diffwith_row('N', value_row(3, 3.0, 'three'), value_row(3, 3.0, 'three')),
diffwith_row('C', value_row(4, None, None), value_row(4, 4.0, 'four')),
diffwith_row('C', value_row(5, 5.0, 'five'), value_row(5, None, None)),
diffwith_row('I', None, value_row(6, 6.0, 'six')),
diffwith_row('D', value_row(7, 7.0, 'seven'), None),
]
diffwith_with_options_row = Row('d', 'l_val', 'r_val')
cls.expected_diffwith_with_options = [
diffwith_with_options_row('c', value_row(1, 1.0, 'one'), value_row(1, 1.1, 'one')),
diffwith_with_options_row('c', value_row(2, 2.0, 'two'), value_row(2, 2.0, 'Two')),
diffwith_with_options_row('n', value_row(3, 3.0, 'three'), value_row(3, 3.0, 'three')),
diffwith_with_options_row('c', value_row(4, None, None), value_row(4, 4.0, 'four')),
diffwith_with_options_row('c', value_row(5, 5.0, 'five'), value_row(5, None, None)),
diffwith_with_options_row('i', None, value_row(6, 6.0, 'six')),
diffwith_with_options_row('r', value_row(7, 7.0, 'seven'), None),
]
cls.expected_diffwith_ignored = [
diffwith_row('C', value_row(1, 1.0, 'one'), value_row(1, 1.1, 'one')),
diffwith_row('N', value_row(2, 2.0, 'two'), value_row(2, 2.0, 'Two')),
diffwith_row('N', value_row(3, 3.0, 'three'), value_row(3, 3.0, 'three')),
diffwith_row('C', value_row(4, None, None), value_row(4, 4.0, 'four')),
diffwith_row('C', value_row(5, 5.0, 'five'), value_row(5, None, None)),
diffwith_row('I', None, value_row(6, 6.0, 'six')),
diffwith_row('D', value_row(7, 7.0, 'seven'), None),
]
diff_with_options_row = Row('d', 'id', 'l_val', 'r_val', 'l_label', 'r_label')
cls.expected_diff_with_options = [
diff_with_options_row('c', 1, 1.0, 1.1, 'one', 'one'),
diff_with_options_row('c', 2, 2.0, 2.0, 'two', 'Two'),
diff_with_options_row('n', 3, 3.0, 3.0, 'three', 'three'),
diff_with_options_row('c', 4, None, 4.0, None, 'four'),
diff_with_options_row('c', 5, 5.0, None, 'five', None),
diff_with_options_row('i', 6, None, 6.0, None, 'six'),
diff_with_options_row('r', 7, 7.0, None, 'seven', None),
]
cls.expected_diff_with_options_ignored = [
diff_with_options_row('c', 1, 1.0, 1.1, 'one', 'one'),
diff_with_options_row('n', 2, 2.0, 2.0, 'two', 'Two'),
diff_with_options_row('n', 3, 3.0, 3.0, 'three', 'three'),
diff_with_options_row('c', 4, None, 4.0, None, 'four'),
diff_with_options_row('c', 5, 5.0, None, 'five', None),
diff_with_options_row('i', 6, None, 6.0, None, 'six'),
diff_with_options_row('r', 7, 7.0, None, 'seven', None),
]
diff_with_changes_row = Row('diff', 'changes', 'id', 'left_val', 'right_val', 'left_label', 'right_label')
cls.expected_diff_with_changes = [
diff_with_changes_row('C', ['val'], 1, 1.0, 1.1, 'one', 'one'),
diff_with_changes_row('C', ['label'], 2, 2.0, 2.0, 'two', 'Two'),
diff_with_changes_row('N', [], 3, 3.0, 3.0, 'three', 'three'),
diff_with_changes_row('C', ['val', 'label'], 4, None, 4.0, None, 'four'),
diff_with_changes_row('C', ['val', 'label'], 5, 5.0, None, 'five', None),
diff_with_changes_row('I', None, 6, None, 6.0, None, 'six'),
diff_with_changes_row('D', None, 7, 7.0, None, 'seven', None),
]
cls.expected_diff_in_column_by_column_mode = cls.expected_diff
diff_in_side_by_side_mode_row = Row('diff', 'id', 'left_val', 'left_label', 'right_val', 'right_label')
cls.expected_diff_in_side_by_side_mode = [
diff_in_side_by_side_mode_row('C', 1, 1.0, 'one', 1.1, 'one'),
diff_in_side_by_side_mode_row('C', 2, 2.0, 'two', 2.0, 'Two'),
diff_in_side_by_side_mode_row('N', 3, 3.0, 'three', 3.0, 'three'),
diff_in_side_by_side_mode_row('C', 4, None, None, 4.0, 'four'),
diff_in_side_by_side_mode_row('C', 5, 5.0, 'five', None, None),
diff_in_side_by_side_mode_row('I', 6, None, None, 6.0, 'six'),
diff_in_side_by_side_mode_row('D', 7, 7.0, 'seven', None, None),
]
diff_in_left_side_mode_row = Row('diff', 'id', 'left_val', 'left_label')
cls.expected_diff_in_left_side_mode = [
diff_in_left_side_mode_row('C', 1, 1.0, 'one'),
diff_in_left_side_mode_row('C', 2, 2.0, 'two'),
diff_in_left_side_mode_row('N', 3, 3.0, 'three'),
diff_in_left_side_mode_row('C', 4, None, None),
diff_in_left_side_mode_row('C', 5, 5.0, 'five'),
diff_in_left_side_mode_row('I', 6, None, None),
diff_in_left_side_mode_row('D', 7, 7.0, 'seven'),
]
diff_in_right_side_mode_row = Row('diff', 'id', 'right_val', 'right_label')
cls.expected_diff_in_right_side_mode = [
diff_in_right_side_mode_row('C', 1, 1.1, 'one'),
diff_in_right_side_mode_row('C', 2, 2.0, 'Two'),
diff_in_right_side_mode_row('N', 3, 3.0, 'three'),
diff_in_right_side_mode_row('C', 4, 4.0, 'four'),
diff_in_right_side_mode_row('C', 5, None, None),
diff_in_right_side_mode_row('I', 6, 6.0, 'six'),
diff_in_right_side_mode_row('D', 7, None, None),
]
diff_in_sparse_mode_row = Row('diff', 'id', 'left_val', 'right_val', 'left_label', 'right_label')
cls.expected_diff_in_sparse_mode = [
diff_in_sparse_mode_row('C', 1, 1.0, 1.1, None, None),
diff_in_sparse_mode_row('C', 2, None, None, 'two', 'Two'),
diff_in_sparse_mode_row('N', 3, None, None, None, None),
diff_in_sparse_mode_row('C', 4, None, 4.0, None, 'four'),
diff_in_sparse_mode_row('C', 5, 5.0, None, 'five', None),
diff_in_sparse_mode_row('I', 6, None, 6.0, None, 'six'),
diff_in_sparse_mode_row('D', 7, 7.0, None, 'seven', None),
]
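    # The fixtures cover ids 1-7: val changed (1), label changed (2), unchanged (3),
    # val and label both changed (4 and 5), right-only insert (6), left-only delete (7).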
def test_check_schema(self):
with self.subTest("duplicate columns"):
with self.assert_requirement("The datasets have duplicate columns.\n"
"Left column names: id, id\nRight column names: id, id"):
self.left_df.select("id", "id").diff(self.right_df.select("id", "id"), "id")
with self.subTest("case-sensitive id column"):
with self.assert_requirement("Some id columns do not exist: ID missing among id, val, label"):
with self.sql_conf({"spark.sql.caseSensitive": "true"}):
self.left_df.diff(self.right_df, "ID")
left = self.left_df.withColumnRenamed("val", "diff")
right = self.right_df.withColumnRenamed("val", "diff")
with self.subTest("id column 'diff'"):
with self.assert_requirement("The id columns must not contain the diff column name 'diff': id, diff, label"):
left.diff(right)
with self.assert_requirement("The id columns must not contain the diff column name 'diff': diff"):
left.diff(right, "diff")
with self.assert_requirement("The id columns must not contain the diff column name 'diff': diff, id"):
left.diff(right, "diff", "id")
with self.sql_conf({"spark.sql.caseSensitive": "false"}):
with self.assert_requirement("The id columns must not contain the diff column name 'diff': Diff, id"):
left.withColumnRenamed("diff", "Diff") \
.diff(right.withColumnRenamed("diff", "Diff"), "Diff", "id")
with self.sql_conf({"spark.sql.caseSensitive": "true"}):
left.withColumnRenamed("diff", "Diff") \
.diff(right.withColumnRenamed("diff", "Diff"), "Diff", "id")
with self.subTest("non-id column 'diff"):
actual = left.diff(right, "id").orderBy("id")
expected_columns = ["diff", "id", "left_diff", "right_diff", "left_label", "right_label"]
self.assertEqual(actual.columns, expected_columns)
self.assertEqual(actual.collect(), self.expected_diff)
with self.subTest("non-id column produces diff column name"):
options = DiffOptions() \
.with_diff_column("a_val") \
.with_left_column_prefix("a") \
.with_right_column_prefix("b")
with self.assert_requirement("The column prefixes 'a' and 'b', together with these non-id columns " +
"must not produce the diff column name 'a_val': val, label"):
self.left_df.diff(self.right_df, options, "id")
with self.assert_requirement("The column prefixes 'a' and 'b', together with these non-id columns " +
"must not produce the diff column name 'b_val': val, label"):
self.left_df.diff(self.right_df, options.with_diff_column("b_val"), "id")
with self.subTest("non-id column would produce diff column name unless in left-side mode"):
options = DiffOptions() \
.with_diff_column("a_val") \
.with_left_column_prefix("a") \
.with_right_column_prefix("b") \
.with_diff_mode(DiffMode.LeftSide)
self.left_df.diff(self.right_df, options, "id")
with self.subTest("non-id column would produce diff column name unless in right-side mode"):
options = DiffOptions() \
.with_diff_column("b_val") \
.with_left_column_prefix("a") \
.with_right_column_prefix("b") \
.with_diff_mode(DiffMode.RightSide)
self.left_df.diff(self.right_df, options, "id")
with self.sql_conf({"spark.sql.caseSensitive": "false"}):
with self.subTest("case-insensitive non-id column produces diff column name"):
options = DiffOptions() \
.with_diff_column("a_val") \
.with_left_column_prefix("A") \
.with_right_column_prefix("b")
with self.assert_requirement("The column prefixes 'A' and 'b', together with these non-id columns " +
"must not produce the diff column name 'a_val': val, label"):
self.left_df.diff(self.right_df, options, "id")
with self.subTest("case-insensitive non-id column would produce diff column name unless in left-side mode"):
options = DiffOptions() \
.with_diff_column("a_val") \
.with_left_column_prefix("A") \
.with_right_column_prefix("B") \
.with_diff_mode(DiffMode.LeftSide)
self.left_df.diff(self.right_df, options, "id")
with self.subTest("case-insensitive non-id column would produce diff column name unless in right-side mode"):
options = DiffOptions() \
.with_diff_column("b_val") \
.with_left_column_prefix("A") \
.with_right_column_prefix("B") \
.with_diff_mode(DiffMode.RightSide)
self.left_df.diff(self.right_df, options, "id")
with self.sql_conf({"spark.sql.caseSensitive": "true"}):
with self.subTest("case-sensitive non-id column produces non-conflicting diff column name"):
options = DiffOptions() \
.with_diff_column("a_val") \
.with_left_column_prefix("A") \
                    .with_right_column_prefix("B")
actual = self.left_df.diff(self.right_df, options, "id").orderBy("id")
expected_columns = ["a_val", "id", "A_val", "B_val", "A_label", "B_label"]
self.assertEqual(actual.columns, expected_columns)
self.assertEqual(actual.collect(), self.expected_diff)
left = self.left_df.withColumnRenamed("val", "change")
right = self.right_df.withColumnRenamed("val", "change")
with self.subTest("id column 'change'"):
options = DiffOptions() \
.with_change_column("change")
with self.assert_requirement("The id columns must not contain the change column name 'change': id, change, label"):
left.diff(right, options)
with self.assert_requirement("The id columns must not contain the change column name 'change': change"):
left.diff(right, options, "change")
with self.assert_requirement("The id columns must not contain the change column name 'change': change, id"):
left.diff(right, options, "change", "id")
with self.sql_conf({"spark.sql.caseSensitive": "false"}):
with self.assert_requirement("The id columns must not contain the change column name 'change': Change, id"):
left.withColumnRenamed("change", "Change") \
.diff(right.withColumnRenamed("change", "Change"), options, "Change", "id")
with self.sql_conf({"spark.sql.caseSensitive": "true"}):
left.withColumnRenamed("change", "Change") \
.diff(right.withColumnRenamed("change", "Change"), options, "Change", "id")
with self.subTest("non-id column 'change"):
actual = left.diff(right, options, "id").orderBy("id")
expected_columns = ["diff", "change", "id", "left_change", "right_change", "left_label", "right_label"]
diff_change_row = Row(*expected_columns)
expected_diff = [
diff_change_row('C', ['change'], 1, 1.0, 1.1, 'one', 'one'),
diff_change_row('C', ['label'], 2, 2.0, 2.0, 'two', 'Two'),
diff_change_row('N', [], 3, 3.0, 3.0, 'three', 'three'),
diff_change_row('C', ['change', 'label'], 4, None, 4.0, None, 'four'),
diff_change_row('C', ['change', 'label'], 5, 5.0, None, 'five', None),
diff_change_row('I', None, 6, None, 6.0, None, 'six'),
diff_change_row('D', None, 7, 7.0, None, 'seven', None),
]
self.assertEqual(actual.columns, expected_columns)
self.assertEqual(actual.collect(), expected_diff)
with self.subTest("non-id column produces change column name"):
options = DiffOptions() \
.with_change_column("a_val") \
.with_left_column_prefix("a") \
.with_right_column_prefix("b")
with self.assert_requirement("The column prefixes 'a' and 'b', together with these non-id columns " +
"must not produce the change column name 'a_val': val, label"):
self.left_df.diff(self.right_df, options, "id")
with self.sql_conf({"spark.sql.caseSensitive": "false"}):
with self.subTest("case-insensitive non-id column produces change column name"):
options = DiffOptions() \
.with_change_column("a_val") \
.with_left_column_prefix("A") \
.with_right_column_prefix("B")
with self.assert_requirement("The column prefixes 'A' and 'B', together with these non-id columns " +
"must not produce the change column name 'a_val': val, label"):
self.left_df.diff(self.right_df, options, "id")
with self.sql_conf({"spark.sql.caseSensitive": "true"}):
with self.subTest("case-sensitive non-id column produces non-conflicting change column name"):
options = DiffOptions() \
.with_change_column("a_val") \
.with_left_column_prefix("A") \
.with_right_column_prefix("B")
actual = self.left_df.diff(self.right_df, options, "id").orderBy("id")
expected_columns = ["diff", "a_val", "id", "A_val", "B_val", "A_label", "B_label"]
self.assertEqual(actual.columns, expected_columns)
self.assertEqual(actual.collect(), self.expected_diff_change)
left = self.left_df.select(col("id").alias("first_id"), col("val").alias("id"), "label")
right = self.right_df.select(col("id").alias("first_id"), col("val").alias("id"), "label")
with self.subTest("non-id column produces id column name"):
options = DiffOptions() \
.with_left_column_prefix("first") \
.with_right_column_prefix("second")
with self.assert_requirement("The column prefixes 'first' and 'second', together with these non-id columns " +
"must not produce any id column name 'first_id': id, label"):
left.diff(right, options, "first_id")
with self.sql_conf({"spark.sql.caseSensitive": "false"}):
with self.subTest("case-insensitive non-id column produces id column name"):
options = DiffOptions() \
.with_left_column_prefix("FIRST") \
.with_right_column_prefix("SECOND")
with self.assert_requirement("The column prefixes 'FIRST' and 'SECOND', together with these non-id columns " +
"must not produce any id column name 'first_id': id, label"):
left.diff(right, options, "first_id")
with self.sql_conf({"spark.sql.caseSensitive": "true"}):
with self.subTest("case-sensitive non-id column produces non-conflicting id column name"):
options = DiffOptions() \
.with_left_column_prefix("FIRST") \
.with_right_column_prefix("SECOND")
actual = left.diff(right, options, "first_id").orderBy("first_id")
expected_columns = ["diff", "first_id", "FIRST_id", "SECOND_id", "FIRST_label", "SECOND_label"]
self.assertEqual(actual.columns, expected_columns)
self.assertEqual(actual.collect(), self.expected_diff)
with self.subTest("empty schema"):
with self.assert_requirement("The schema must not be empty"):
self.left_df.select().diff(self.right_df.select())
with self.subTest("empty schema after ignored columns"):
with self.assert_requirement("The schema except ignored columns must not be empty"):
self.left_df.select("id", "val").diff(self.right_df.select("id", "label"), [], ["id", "val", "label"])
with self.subTest("different types"):
with self.assert_requirement("The datasets do not have the same schema.\n" +
"Left extra columns: val (double)\n" +
"Right extra columns: val (string)"):
self.left_df.select("id", "val").diff(self.right_df.select("id", col("label").alias("val")))
with self.subTest("ignore columns with different types"):
actual = self.left_df.select("id", "val").diff(self.right_df.select("id", col("label").alias("val")), [], ["val"])
expected_schema = [
("diff", StringType()),
("id", LongType()),
("left_val", DoubleType()),
("right_val", StringType()),
]
self.assertEqual([(f.name, f.dataType) for f in actual.schema], expected_schema)
with self.subTest("diff with different column names"):
with self.assert_requirement("The datasets do not have the same schema.\n" +
"Left extra columns: val (double)\n" +
"Right extra columns: label (string)"):
self.left_df.select("id", "val").diff(self.right_df.select("id", "label"))
left = self.left_df.select("id", "val", "label")
right = self.right_df.select(col("id").alias("ID"), col("val").alias("VaL"), "label")
with self.sql_conf({"spark.sql.caseSensitive": "false"}):
with self.subTest("case-insensitive column names"):
actual = left.diff(right, "id").orderBy("id")
reverse = right.diff(left, "id").orderBy("id")
self.assertEqual(actual.columns, ["diff", "id", "left_val", "right_VaL", "left_label", "right_label"])
self.assertEqual(actual.collect(), self.expected_diff)
self.assertEqual(reverse.columns, ["diff", "id", "left_VaL", "right_val", "left_label", "right_label"])
self.assertEqual(reverse.collect(), self.expected_diff_reversed)
with self.sql_conf({"spark.sql.caseSensitive": "true"}):
with self.subTest("case-sensitive column names"):
with self.assert_requirement("The datasets do not have the same schema.\n" +
"Left extra columns: id (long), val (double)\n" +
"Right extra columns: ID (long), VaL (double)"):
left.diff(right, "id")
with self.subTest("non-existing id column"):
with self.assert_requirement("Some id columns do not exist: does not exists missing among id, val, label"):
self.left_df.diff(self.right_df, "does not exists")
with self.subTest("different number of columns"):
with self.assert_requirement("The number of columns doesn't match.\n" +
"Left column names (2): id, val\n" +
"Right column names (3): id, val, label"):
self.left_df.select("id", "val").diff(self.right_df, "id")
with self.subTest("different number of columns after ignoring columns"):
left = self.left_df.select("id", "val", col("label").alias("meta"))
right = self.right_df.select("id", col("label").alias("seq"), "val")
with self.assert_requirement("The number of columns doesn't match.\n" +
"Left column names except ignored columns (2): id, val\n" +
"Right column names except ignored columns (3): id, seq, val"):
left.diff(right, ["id"], ["meta"])
with self.subTest("diff column name in value columns in left-side diff mode"):
options = DiffOptions().with_diff_column("val").with_diff_mode(DiffMode.LeftSide)
with self.assert_requirement("The left non-id columns must not contain the diff column name 'val': val, label"):
self.left_df.diff(self.right_df, options, "id")
with self.subTest("diff column name in value columns in right-side diff mode"):
options = DiffOptions().with_diff_column("val").with_diff_mode(DiffMode.RightSide)
with self.assert_requirement("The right non-id columns must not contain the diff column name 'val': val, label"):
self.left_df.diff(self.right_df, options, "id")
with self.subTest("change column name in value columns in left-side diff mode"):
options = DiffOptions().with_change_column("val").with_diff_mode(DiffMode.LeftSide)
with self.assert_requirement("The left non-id columns must not contain the change column name 'val': val, label"):
self.left_df.diff(self.right_df, options, "id")
with self.subTest("change column name in value columns in right-side diff mode"):
options = DiffOptions().with_change_column("val").with_diff_mode(DiffMode.RightSide)
with self.assert_requirement("The right non-id columns must not contain the change column name 'val': val, label"):
self.left_df.diff(self.right_df, options, "id")
def test_dataframe_diff(self):
diff = self.left_df.diff(self.right_df, 'id').orderBy('id').collect()
self.assertEqual(self.expected_diff, diff)
def test_dataframe_diff_with_ids_ignored(self):
diff = self.left_df.diff(self.right_df, ['id'], ['label']).orderBy('id').collect()
self.assertEqual(self.expected_diff_ignored, diff)
def test_dataframe_diff_with_wrong_argument_types(self):
with self.subTest("id columns is not string"):
with self.assert_requirement("The id_or_ignore_columns argument must either all be strings "
"or exactly two iterables of strings: int"):
self.left_df.diff(self.right_df, 1)
with self.subTest("one of two id columns is not string"):
with self.assert_requirement("The id_or_ignore_columns argument must either all be strings "
"or exactly two iterables of strings: str, int"):
self.left_df.diff(self.right_df, "id", 1)
with self.subTest("one of three id columns is not string"):
with self.assert_requirement("The id_or_ignore_columns argument must either all be strings "
"or exactly two iterables of strings: str, int, str"):
self.left_df.diff(self.right_df, "id", 1, "val")
with self.subTest("id columns is not list"):
with self.assert_requirement("The id_or_ignore_columns argument must either all be strings "
"or exactly two iterables of strings: int, list"):
self.left_df.diff(self.right_df, 1, ['val'])
with self.subTest("one of id columns is not string"):
with self.assert_requirement("The id_columns must all be strings: str, int"):
self.left_df.diff(self.right_df, ['id', 1], ['val'])
with self.subTest("ignore columns is not list"):
with self.assert_requirement("The id_or_ignore_columns argument must either all be strings "
"or exactly two iterables of strings: list, int"):
self.left_df.diff(self.right_df, ['id'], 1)
with self.subTest("one of ignore columns is not string"):
with self.assert_requirement("The ignore_columns must all be strings: str, int"):
self.left_df.diff(self.right_df, ['id'], ['val', 1])
with self.subTest("one list of string id columns"):
with self.assert_requirement("The id_or_ignore_columns argument must either all be strings "
"or exactly two iterables of strings: list"):
self.left_df.diff(self.right_df, ['id'])
with self.subTest("three lists of string id and ignore columns"):
with self.assert_requirement("The id_or_ignore_columns argument must either all be strings "
"or exactly two iterables of strings: list, list, list"):
self.left_df.diff(self.right_df, ['id'], ['val'], ['three'])
with self.subTest("options not second argument"):
with self.assert_requirement("Diff options must be given as second argument"):
self.left_df.diff(self.right_df, 'id', DiffOptions())
def test_dataframe_diffwith(self):
diff = self.left_df.diffwith(self.right_df, 'id').orderBy('id').collect()
self.assertSetEqual(set(self.expected_diffwith), set(diff))
self.assertEqual(len(self.expected_diffwith), len(diff))
def test_dataframe_diffwith_with_default_options(self):
diff = self.left_df.diffwith(self.right_df, DiffOptions(), 'id').orderBy('id').collect()
self.assertSetEqual(set(self.expected_diffwith), set(diff))
self.assertEqual(len(self.expected_diffwith), len(diff))
def test_dataframe_diffwith_with_options(self):
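        # Positional DiffOptions arguments: diff column 'd', column prefixes 'l'/'r',
        # diff values 'i'/'c'/'r'/'n' (insert/change/delete/no-change) and no change column,
        # matching the fluent setters exercised in test_diff_fluent_setters.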
options = DiffOptions('d', 'l', 'r', 'i', 'c', 'r', 'n', None)
diff = self.left_df.diffwith(self.right_df, options, 'id').orderBy('id').collect()
self.assertSetEqual(set(self.expected_diffwith_with_options), set(diff))
self.assertEqual(len(self.expected_diffwith_with_options), len(diff))
def test_dataframe_diffwith_with_ignored(self):
diff = self.left_df.diffwith(self.right_df, ['id'], ['label']).orderBy('id').collect()
self.assertSetEqual(set(self.expected_diffwith_ignored), set(diff))
self.assertEqual(len(self.expected_diffwith_ignored), len(diff))
def test_dataframe_diffwith_with_wrong_argument_types(self):
with self.subTest("id columns is not string"):
with self.assert_requirement("The id_or_ignore_columns argument must either all be strings "
"or exactly two iterables of strings: int"):
self.left_df.diffwith(self.right_df, 1)
with self.subTest("one of two id columns is not string"):
with self.assert_requirement("The id_or_ignore_columns argument must either all be strings "
"or exactly two iterables of strings: str, int"):
self.left_df.diffwith(self.right_df, "id", 1)
with self.subTest("one of three id columns is not string"):
with self.assert_requirement("The id_or_ignore_columns argument must either all be strings "
"or exactly two iterables of strings: str, int, str"):
self.left_df.diffwith(self.right_df, "id", 1, "val")
with self.subTest("id columns is not list"):
with self.assert_requirement("The id_or_ignore_columns argument must either all be strings "
"or exactly two iterables of strings: int, list"):
self.left_df.diffwith(self.right_df, 1, ['val'])
with self.subTest("one of id columns is not string"):
with self.assert_requirement("The id_columns must all be strings: str, int"):
self.left_df.diffwith(self.right_df, ['id', 1], ['val'])
with self.subTest("ignore columns is not list"):
with self.assert_requirement("The id_or_ignore_columns argument must either all be strings "
"or exactly two iterables of strings: list, int"):
self.left_df.diffwith(self.right_df, ['id'], 1)
with self.subTest("one of ignore columns is not string"):
with self.assert_requirement("The ignore_columns must all be strings: str, int"):
self.left_df.diffwith(self.right_df, ['id'], ['val', 1])
with self.subTest("one list of string id columns"):
with self.assert_requirement("The id_or_ignore_columns argument must either all be strings "
"or exactly two iterables of strings: list"):
self.left_df.diffwith(self.right_df, ['id'])
with self.subTest("three lists of string id and ignore columns"):
with self.assert_requirement("The id_or_ignore_columns argument must either all be strings "
"or exactly two iterables of strings: list, list, list"):
self.left_df.diffwith(self.right_df, ['id'], ['val'], ['three'])
with self.subTest("options not second argument"):
with self.assert_requirement("Diff options must be given as second argument"):
self.left_df.diffwith(self.right_df, 'id', DiffOptions())
def test_dataframe_diff_with_default_options(self):
diff = self.left_df.diff(self.right_df, DiffOptions(), 'id').orderBy('id').collect()
self.assertEqual(self.expected_diff, diff)
diff = self.left_df.diff_with_options(self.right_df, DiffOptions(), 'id').orderBy('id').collect()
self.assertEqual(self.expected_diff, diff)
def test_dataframe_diff_with_options(self):
options = DiffOptions('d', 'l', 'r', 'i', 'c', 'r', 'n', None)
diff = self.left_df.diff(self.right_df, options, 'id').orderBy('id').collect()
self.assertEqual(self.expected_diff_with_options, diff)
diff = self.left_df.diff_with_options(self.right_df, options, 'id').orderBy('id').collect()
self.assertEqual(self.expected_diff_with_options, diff)
def test_dataframe_diff_with_options_and_ignored(self):
options = DiffOptions('d', 'l', 'r', 'i', 'c', 'r', 'n', None)
diff = self.left_df.diff(self.right_df, options, ['id'], ['label']).orderBy('id').collect()
self.assertEqual(self.expected_diff_with_options_ignored, diff)
diff = self.left_df.diff_with_options(self.right_df, options, ['id'], ['label']).orderBy('id').collect()
self.assertEqual(self.expected_diff_with_options_ignored, diff)
def test_dataframe_diff_with_changes(self):
options = DiffOptions().with_change_column('changes')
diff = self.left_df.diff(self.right_df, options, 'id').orderBy('id').collect()
self.assertEqual(self.expected_diff_with_changes, diff)
diff = self.left_df.diff_with_options(self.right_df, options, 'id').orderBy('id').collect()
self.assertEqual(self.expected_diff_with_changes, diff)
def test_dataframe_diff_with_diff_mode_column_by_column(self):
options = DiffOptions().with_diff_mode(DiffMode.ColumnByColumn)
diff = self.left_df.diff(self.right_df, options, 'id').orderBy('id').collect()
self.assertEqual(self.expected_diff_in_column_by_column_mode, diff)
diff = self.left_df.diff_with_options(self.right_df, options, 'id').orderBy('id').collect()
self.assertEqual(self.expected_diff_in_column_by_column_mode, diff)
def test_dataframe_diff_with_diff_mode_side_by_side(self):
options = DiffOptions().with_diff_mode(DiffMode.SideBySide)
diff = self.left_df.diff(self.right_df, options, 'id').orderBy('id').collect()
self.assertEqual(self.expected_diff_in_side_by_side_mode, diff)
diff = self.left_df.diff_with_options(self.right_df, options, 'id').orderBy('id').collect()
self.assertEqual(self.expected_diff_in_side_by_side_mode, diff)
def test_dataframe_diff_with_diff_mode_left_side(self):
options = DiffOptions().with_diff_mode(DiffMode.LeftSide)
diff = self.left_df.diff(self.right_df, options, 'id').orderBy('id').collect()
self.assertEqual(self.expected_diff_in_left_side_mode, diff)
diff = self.left_df.diff_with_options(self.right_df, options, 'id').orderBy('id').collect()
self.assertEqual(self.expected_diff_in_left_side_mode, diff)
def test_dataframe_diff_with_diff_mode_right_side(self):
options = DiffOptions().with_diff_mode(DiffMode.RightSide)
diff = self.left_df.diff(self.right_df, options, 'id').orderBy('id').collect()
self.assertEqual(self.expected_diff_in_right_side_mode, diff)
diff = self.left_df.diff_with_options(self.right_df, options, 'id').orderBy('id').collect()
self.assertEqual(self.expected_diff_in_right_side_mode, diff)
def test_dataframe_diff_with_sparse_mode(self):
options = DiffOptions().with_sparse_mode(True)
diff = self.left_df.diff(self.right_df, options, 'id').orderBy('id').collect()
self.assertEqual(self.expected_diff_in_sparse_mode, diff)
diff = self.left_df.diff_with_options(self.right_df, options, 'id').orderBy('id').collect()
self.assertEqual(self.expected_diff_in_sparse_mode, diff)
def test_differ_diff(self):
diff = Differ().diff(self.left_df, self.right_df, 'id').orderBy('id').collect()
self.assertEqual(self.expected_diff, diff)
def test_differ_diffwith(self):
diff = Differ().diffwith(self.left_df, self.right_df, 'id').orderBy('id').collect()
self.assertSetEqual(set(self.expected_diffwith), set(diff))
self.assertEqual(len(self.expected_diffwith), len(diff))
def test_differ_diff_with_default_options(self):
options = DiffOptions()
diff = Differ(options).diff(self.left_df, self.right_df, 'id').orderBy('id').collect()
self.assertEqual(self.expected_diff, diff)
def test_differ_diff_with_options(self):
options = DiffOptions('d', 'l', 'r', 'i', 'c', 'r', 'n', None)
diff = Differ(options).diff(self.left_df, self.right_df, 'id').orderBy('id').collect()
self.assertEqual(self.expected_diff_with_options, diff)
def test_differ_diff_with_changes(self):
options = DiffOptions().with_change_column('changes')
diff = Differ(options).diff(self.left_df, self.right_df, 'id').orderBy('id').collect()
self.assertEqual(self.expected_diff_with_changes, diff)
def test_differ_diff_in_diff_mode_column_by_column(self):
options = DiffOptions().with_diff_mode(DiffMode.ColumnByColumn)
diff = Differ(options).diff(self.left_df, self.right_df, 'id').orderBy('id').collect()
self.assertEqual(self.expected_diff_in_column_by_column_mode, diff)
def test_differ_diff_in_diff_mode_side_by_side(self):
options = DiffOptions().with_diff_mode(DiffMode.SideBySide)
diff = Differ(options).diff(self.left_df, self.right_df, 'id').orderBy('id').collect()
self.assertEqual(self.expected_diff_in_side_by_side_mode, diff)
def test_differ_diff_in_diff_mode_left_side(self):
options = DiffOptions().with_diff_mode(DiffMode.LeftSide)
diff = Differ(options).diff(self.left_df, self.right_df, 'id').orderBy('id').collect()
self.assertEqual(self.expected_diff_in_left_side_mode, diff)
def test_differ_diff_in_diff_mode_right_side(self):
options = DiffOptions().with_diff_mode(DiffMode.RightSide)
diff = Differ(options).diff(self.left_df, self.right_df, 'id').orderBy('id').collect()
self.assertEqual(self.expected_diff_in_right_side_mode, diff)
def test_differ_diff_with_sparse_mode(self):
options = DiffOptions().with_sparse_mode(True)
diff = Differ(options).diff(self.left_df, self.right_df, 'id').orderBy('id').collect()
self.assertEqual(self.expected_diff_in_sparse_mode, diff)
@skipIf(SparkTest.is_spark_connect, "Spark Connect does not provide access to the JVM")
def test_diff_options_default(self):
jvm = self.spark._jvm
joptions = jvm.uk.co.gresearch.spark.diff.DiffOptions.default()
options = DiffOptions()
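        # Compare each Python snake_case option (e.g. left_column_prefix) against the
        # corresponding Java camelCase getter (leftColumnPrefix) on the default options.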
for attr in options.__dict__.keys():
const = re.sub(r'_(.)', lambda match: match.group(1).upper(), attr)
expected = getattr(joptions, const)()
actual = getattr(options, attr)
if type(expected) == JavaObject:
class_name = re.sub(r'\$.*$', '', expected.getClass().getName())
                if class_name in ['scala.None']:  # what does the Some(...) case look like?
actual = 'Some({})'.format(actual) if actual is not None else 'None'
if class_name in ['scala.collection.immutable.Map', 'scala.collection.mutable.Map']:
actual = f'Map({", ".join(f"{key} -> {value._to_java(jvm).toString()}" for key, value in actual.items())})'
expected = expected.toString()
if attr in ['diff_mode', 'default_comparator']:
# does the Python default diff mode resolve to the same Java diff mode enum value?
# does the Python diff comparator resolve to the same Java diff comparator?
self.assertEqual(expected, actual._to_java(jvm).toString(), '{} == {} ?'.format(attr, const))
else:
self.assertEqual(expected, actual, '{} == {} ?'.format(attr, const))
@skipIf(SparkTest.is_spark_connect, "Spark Connect does not provide access to the JVM")
def test_diff_mode_consts(self):
jvm = self.spark._jvm
jmodes = jvm.uk.co.gresearch.spark.diff.DiffMode
modes = DiffMode
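        # Each Python DiffMode member must stringify like its Java counterpart;
        # Default is excluded from the loop and checked separately below.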
for attr in modes.__dict__.keys():
if attr[0] != '_':
actual = getattr(modes, attr)
if isinstance(actual, DiffMode) and actual != DiffMode.Default:
expected = getattr(jmodes, attr)()
self.assertEqual(expected.toString(), actual.name, actual.name)
self.assertIsNotNone(DiffMode.Default.name, jmodes.Default().toString())
def test_diff_options_comparator_for(self):
cmp1 = DiffComparators.default()
cmp2 = DiffComparators.epsilon(0.01)
cmp3 = DiffComparators.string()
opts = DiffOptions() \
.with_column_name_comparator(cmp1, "abc", "def") \
.with_data_type_comparator(cmp2, LongType()) \
.with_default_comparator(cmp3)
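        # Lookup precedence asserted below: a column-name comparator wins over a
        # data-type comparator, which in turn wins over the default comparator.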
self.assertEqual(opts.comparator_for(StructField("abc", IntegerType())), cmp1)
self.assertEqual(opts.comparator_for(StructField("def", LongType())), cmp1)
self.assertEqual(opts.comparator_for(StructField("ghi", LongType())), cmp2)
self.assertEqual(opts.comparator_for(StructField("jkl", IntegerType())), cmp3)
def test_diff_fluent_setters(self):
cmp1 = DiffComparators.default()
cmp2 = DiffComparators.epsilon(0.01)
cmp3 = DiffComparators.string()
cmp4 = DiffComparators.duration('PT24H')
default = DiffOptions()
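        # The with_* setters are fluent and non-mutating: `default` must remain
        # unchanged, as the assertNotEqual checks below verify.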
options = default \
.with_diff_column('d') \
.with_left_column_prefix('l') \
.with_right_column_prefix('r') \
.with_insert_diff_value('i') \
.with_change_diff_value('c') \
.with_delete_diff_value('r') \
.with_nochange_diff_value('n') \
.with_change_column('c') \
.with_diff_mode(DiffMode.SideBySide) \
.with_sparse_mode(True) \
.with_default_comparator(cmp1) \
.with_data_type_comparator(cmp2, IntegerType()) \
.with_data_type_comparator(cmp3, StringType()) \
.with_column_name_comparator(cmp4, 'value')
self.assertEqual(options.diff_column, 'd')
self.assertEqual(options.left_column_prefix, 'l')
self.assertEqual(options.right_column_prefix, 'r')
self.assertEqual(options.insert_diff_value, 'i')
self.assertEqual(options.change_diff_value, 'c')
self.assertEqual(options.delete_diff_value, 'r')
self.assertEqual(options.nochange_diff_value, 'n')
self.assertEqual(options.change_column, 'c')
self.assertEqual(options.diff_mode, DiffMode.SideBySide)
self.assertEqual(options.sparse_mode, True)
self.assertEqual(options.default_comparator, cmp1)
self.assertEqual(options.data_type_comparators, {IntegerType(): cmp2, StringType(): cmp3})
self.assertEqual(options.column_name_comparators, {'value': cmp4})
self.assertNotEqual(options.diff_column, default.diff_column)
self.assertNotEqual(options.left_column_prefix, default.left_column_prefix)
self.assertNotEqual(options.right_column_prefix, default.right_column_prefix)
self.assertNotEqual(options.insert_diff_value, default.insert_diff_value)
self.assertNotEqual(options.change_diff_value, default.change_diff_value)
self.assertNotEqual(options.delete_diff_value, default.delete_diff_value)
self.assertNotEqual(options.nochange_diff_value, default.nochange_diff_value)
self.assertNotEqual(options.change_column, default.change_column)
self.assertNotEqual(options.diff_mode, default.diff_mode)
self.assertNotEqual(options.sparse_mode, default.sparse_mode)
without_change = options.without_change_column()
self.assertEqual(without_change.diff_column, 'd')
self.assertEqual(without_change.left_column_prefix, 'l')
self.assertEqual(without_change.right_column_prefix, 'r')
self.assertEqual(without_change.insert_diff_value, 'i')
self.assertEqual(without_change.change_diff_value, 'c')
self.assertEqual(without_change.delete_diff_value, 'r')
self.assertEqual(without_change.nochange_diff_value, 'n')
self.assertIsNone(without_change.change_column)
self.assertEqual(without_change.diff_mode, DiffMode.SideBySide)
self.assertEqual(without_change.sparse_mode, True)
def test_diff_with_epsilon_comparator(self):
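        # For id=1, val differs by 0.1 (exactly 0.10000000000000009 as a double).
        # Inclusive epsilons treat a difference equal to the threshold as unchanged ('N');
        # exclusive epsilons require it to be strictly below the threshold, so 'C' remains.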
# relative inclusive epsilon
options = DiffOptions() \
.with_column_name_comparator(DiffComparators.epsilon(0.1).as_relative().as_inclusive(), 'val')
diff = self.left_df.diff(self.right_df, options, 'id').orderBy('id').collect()
expected = self.spark.createDataFrame(self.expected_diff) \
.withColumn("diff", when(col("id") == 1, "N").otherwise(col("diff"))) \
.collect()
self.assertEqual(expected, diff)
# relative exclusive epsilon
options = DiffOptions() \
.with_column_name_comparator(DiffComparators.epsilon(0.0909).as_relative().as_exclusive(), 'val')
diff = self.left_df.diff(self.right_df, options, 'id').orderBy('id').collect()
self.assertEqual(self.expected_diff, diff)
# absolute inclusive epsilon
options = DiffOptions() \
.with_column_name_comparator(DiffComparators.epsilon(0.10000000000000009).as_absolute().as_inclusive(), 'val')
diff = self.left_df.diff(self.right_df, options, 'id').orderBy('id').collect()
self.assertEqual(expected, diff)
# absolute exclusive epsilon
options = DiffOptions() \
.with_column_name_comparator(DiffComparators.epsilon(0.10000000000000009).as_absolute().as_exclusive(), 'val')
diff = self.left_df.diff(self.right_df, options, 'id').orderBy('id').collect()
self.assertEqual(self.expected_diff, diff)
def test_diff_options_with_duplicate_comparators(self):
options = DiffOptions() \
.with_data_type_comparator(DiffComparators.default(), DateType(), IntegerType()) \
.with_column_name_comparator(DiffComparators.default(), 'col1', 'col2')
with self.assertRaisesRegex(ValueError, "A comparator for data type date exists already."):
options.with_data_type_comparator(DiffComparators.default(), DateType())
with self.assertRaisesRegex(ValueError, "A comparator for data type int exists already."):
options.with_data_type_comparator(DiffComparators.default(), IntegerType())
with self.assertRaisesRegex(ValueError, "A comparator for data types date, int exists already."):
options.with_data_type_comparator(DiffComparators.default(), DateType(), IntegerType())
with self.assertRaisesRegex(ValueError, "A comparator for column name col1 exists already."):
options.with_column_name_comparator(DiffComparators.default(), 'col1')
with self.assertRaisesRegex(ValueError, "A comparator for column name col2 exists already."):
options.with_column_name_comparator(DiffComparators.default(), 'col2')
with self.assertRaisesRegex(ValueError, "A comparator for column names col1, col2 exists already."):
options.with_column_name_comparator(DiffComparators.default(), 'col1', 'col2')
if __name__ == '__main__':
SparkTest.main(__file__)
================================================
FILE: python/test/test_histogram.py
================================================
# Copyright 2020 G-Research
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from unittest import skipIf
from spark_common import SparkTest
import gresearch.spark
@skipIf(SparkTest.is_spark_connect, "Spark Connect does not provide access to the JVM, required by Historgam")
class HistogramTest(SparkTest):
@classmethod
def setUpClass(cls):
super(HistogramTest, cls).setUpClass()
cls.df = cls.spark.createDataFrame([
(1, 1),
(1, 2),
(1, 10),
(2, -3),
(2, 5),
(3, 8),
], ['id', 'value'])
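    # df.histogram(thresholds, value_column, *aggregate_columns) pivots 'value' per 'id'
    # into bucket counts: one '≤t' column per threshold plus a '>t' overflow column.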
def test_histogram_with_ints(self):
hist = self.df.histogram([-5, 0, 5], 'value', 'id').orderBy('id').collect()
self.assertEqual([
{'id': 1, '≤-5': 0, '≤0': 0, '≤5': 2, '>5': 1},
{'id': 2, '≤-5': 0, '≤0': 1, '≤5': 1, '>5': 0},
{'id': 3, '≤-5': 0, '≤0': 0, '≤5': 0, '>5': 1},
], [row.asDict() for row in hist])
def test_histogram_with_floats(self):
hist = self.df.histogram([-5.0, 0.0, 5.0], 'value', 'id').orderBy('id').collect()
self.assertEqual([
{'id': 1, '≤-5.0': 0, '≤0.0': 0, '≤5.0': 2, '>5.0': 1},
{'id': 2, '≤-5.0': 0, '≤0.0': 1, '≤5.0': 1, '>5.0': 0},
{'id': 3, '≤-5.0': 0, '≤0.0': 0, '≤5.0': 0, '>5.0': 1},
], [row.asDict() for row in hist])
if __name__ == '__main__':
SparkTest.main(__file__)
================================================
FILE: python/test/test_job_description.py
================================================
# Copyright 2023 G-Research
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from unittest import skipIf
from pyspark import TaskContext, SparkContext
from typing import Optional
from spark_common import SparkTest
from gresearch.spark import job_description, append_job_description
@skipIf(SparkTest.is_spark_connect, "Spark Connect does not provide access to the JVM, required by JobDescription")
class JobDescriptionTest(SparkTest):
def _assert_job_description(self, expected: Optional[str]):
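        # Run a three-partition job and read the "spark.job.description" local property
        # from each task's context; every task must observe `expected`.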
def get_job_description_func(part):
def func(row):
return row.id, part, TaskContext.get().getLocalProperty("spark.job.description")
return func
descriptions = self.spark.range(3, numPartitions=3).rdd \
.mapPartitionsWithIndex(lambda part, it: map(get_job_description_func(part), it)) \
.collect()
self.assertEqual(
[(0, 0, expected), (1, 1, expected), (2, 2, expected)],
descriptions
)
def setUp(self) -> None:
SparkContext._active_spark_context.setJobDescription(None)
def test_with_job_description(self):
self._assert_job_description(None)
with job_description("job description"):
self._assert_job_description("job description")
with job_description("inner job description"):
self._assert_job_description("inner job description")
self._assert_job_description("job description")
with job_description("inner job description", True):
self._assert_job_description("job description")
self._assert_job_description("job description")
self._assert_job_description(None)
with job_description("other job description", True):
self._assert_job_description("other job description")
self._assert_job_description(None)
def test_append_job_description(self):
self._assert_job_description(None)
with append_job_description("job"):
self._assert_job_description("job")
with append_job_description("description"):
self._assert_job_description("job - description")
self._assert_job_description("job")
with append_job_description("description 2", " "):
self._assert_job_description("job description 2")
self._assert_job_description("job")
self._assert_job_description(None)
if __name__ == '__main__':
SparkTest.main(__file__)
================================================
FILE: python/test/test_jvm.py
================================================
# Copyright 2024 G-Research
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from unittest import skipIf, skipUnless
from pyspark.sql.functions import sum
from gresearch.spark import _get_jvm, \
dotnet_ticks_to_timestamp, dotnet_ticks_to_unix_epoch, dotnet_ticks_to_unix_epoch_nanos, \
timestamp_to_dotnet_ticks, unix_epoch_to_dotnet_ticks, unix_epoch_nanos_to_dotnet_ticks, \
histogram, job_description, append_description
from gresearch.spark.diff import *
from gresearch.spark.parquet import *
from spark_common import SparkTest
EXPECTED_UNSUPPORTED_MESSAGE = "This feature is not supported for Spark Connect. Please use a classic Spark client. " \
"https://github.com/G-Research/spark-extension#spark-connect-server"
class PackageTest(SparkTest):
df = None
@classmethod
def setUpClass(cls):
super(PackageTest, cls).setUpClass()
cls.df = cls.spark.createDataFrame([(1, "one"), (2, "two"), (3, "three")], ["id", "value"])
@skipIf(SparkTest.is_spark_connect, "Spark classic client tests")
def test_get_jvm_classic(self):
for obj in [self.spark, self.spark.sparkContext, self.df, self.spark.read]:
with self.subTest(type(obj).__name__):
self.assertIsNotNone(_get_jvm(obj))
with self.subTest("Unsupported"):
with self.assertRaises(RuntimeError) as e:
_get_jvm(object())
self.assertEqual(("Unsupported class: ", ), e.exception.args)
@skipUnless(SparkTest.is_spark_connect, "Spark connect client tests")
def test_get_jvm_connect(self):
for obj in [self.spark, self.df, self.spark.read]:
with self.subTest(type(obj).__name__):
with self.assertRaises(RuntimeError) as e:
_get_jvm(obj)
self.assertEqual((EXPECTED_UNSUPPORTED_MESSAGE, ), e.exception.args)
with self.subTest("Unsupported"):
with self.assertRaises(RuntimeError) as e:
_get_jvm(object())
self.assertEqual(("Unsupported class: ", ), e.exception.args)
@skipIf(SparkTest.is_spark_connect, "Spark classic client tests")
def test_get_jvm_check_java_pkg_is_installed(self):
from gresearch import spark
is_installed = spark._java_pkg_is_installed
try:
spark._java_pkg_is_installed = False
with self.assertRaises(RuntimeError) as e:
_get_jvm(self.spark)
self.assertEqual(("Java / Scala package not found! You need to add the Maven spark-extension package "
"to your PySpark environment: https://github.com/G-Research/spark-extension#python", ), e.exception.args)
finally:
spark._java_pkg_is_installed = is_installed
@skipUnless(SparkTest.is_spark_connect, "Spark connect client tests")
def test_dotnet_ticks(self):
for label, func in {
'dotnet_ticks_to_timestamp': dotnet_ticks_to_timestamp,
'dotnet_ticks_to_unix_epoch': dotnet_ticks_to_unix_epoch,
'dotnet_ticks_to_unix_epoch_nanos': dotnet_ticks_to_unix_epoch_nanos,
'timestamp_to_dotnet_ticks': timestamp_to_dotnet_ticks,
'unix_epoch_to_dotnet_ticks': unix_epoch_to_dotnet_ticks,
'unix_epoch_nanos_to_dotnet_ticks': unix_epoch_nanos_to_dotnet_ticks,
}.items():
with self.subTest(label):
with self.assertRaises(RuntimeError) as e:
func("id")
self.assertEqual(("This method must be called inside an active Spark session", ), e.exception.args)
@skipUnless(SparkTest.is_spark_connect, "Spark connect client tests")
def test_histogram(self):
with self.assertRaises(RuntimeError) as e:
self.df.histogram([1, 10, 100], "bin", sum)
self.assertEqual((EXPECTED_UNSUPPORTED_MESSAGE, ), e.exception.args)
@skipUnless(SparkTest.is_spark_connect, "Spark connect client tests")
def test_with_row_numbers(self):
with self.assertRaises(RuntimeError) as e:
self.df.with_row_numbers()
self.assertEqual((EXPECTED_UNSUPPORTED_MESSAGE, ), e.exception.args)
@skipUnless(SparkTest.is_spark_connect, "Spark connect client tests")
def test_job_description(self):
with self.assertRaises(RuntimeError) as e:
with job_description("job description"):
pass
self.assertEqual(("This method must be called inside an active Spark session", ), e.exception.args)
with self.assertRaises(RuntimeError) as e:
with append_description("job description"):
pass
self.assertEqual(("This method must be called inside an active Spark session", ), e.exception.args)
@skipUnless(SparkTest.is_spark_connect, "Spark connect client tests")
def test_create_temp_dir(self):
with self.assertRaises(RuntimeError) as e:
self.spark.create_temporary_dir("prefix")
self.assertEqual((EXPECTED_UNSUPPORTED_MESSAGE, ), e.exception.args)
@skipUnless(SparkTest.is_spark_connect, "Spark connect client tests")
def test_install_pip_package(self):
with self.assertRaises(RuntimeError) as e:
self.spark.install_pip_package("pytest")
self.assertEqual((EXPECTED_UNSUPPORTED_MESSAGE, ), e.exception.args)
@skipUnless(SparkTest.is_spark_connect, "Spark connect client tests")
def test_install_poetry_project(self):
with self.assertRaises(RuntimeError) as e:
self.spark.install_poetry_project("./poetry-project")
self.assertEqual((EXPECTED_UNSUPPORTED_MESSAGE, ), e.exception.args)
@skipUnless(SparkTest.is_spark_connect, "Spark connect client tests")
def test_parquet(self):
for label, func in {
'parquet_metadata': lambda dr: dr.parquet_metadata("file.parquet"),
'parquet_schema': lambda dr: dr.parquet_schema("file.parquet"),
'parquet_blocks': lambda dr: dr.parquet_blocks("file.parquet"),
'parquet_block_columns': lambda dr: dr.parquet_block_columns("file.parquet"),
'parquet_partitions': lambda dr: dr.parquet_partitions("file.parquet"),
}.items():
with self.subTest(label):
with self.assertRaises(RuntimeError) as e:
func(self.spark.read)
self.assertEqual((EXPECTED_UNSUPPORTED_MESSAGE, ), e.exception.args)
if __name__ == '__main__':
SparkTest.main(__file__)
================================================
FILE: python/test/test_package.py
================================================
# Copyright 2023 G-Research
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import datetime
import os
from decimal import Decimal
from subprocess import CalledProcessError
from unittest import skipUnless, skipIf
from pyspark import __version__, SparkContext
from pyspark.sql import Row, SparkSession, SQLContext
from pyspark.sql.functions import col, count
from gresearch.spark import backticks, distinct_prefix_for, handle_configured_case_sensitivity, \
list_contains_case_sensitivity, list_filter_case_sensitivity, list_diff_case_sensitivity, \
dotnet_ticks_to_timestamp, dotnet_ticks_to_unix_epoch, dotnet_ticks_to_unix_epoch_nanos, \
timestamp_to_dotnet_ticks, unix_epoch_to_dotnet_ticks, unix_epoch_nanos_to_dotnet_ticks, count_null
from spark_common import SparkTest
try:
from pyspark.sql.connect.session import SparkSession as ConnectSparkSession
has_connect = True
except ImportError:
has_connect = False
POETRY_PYTHON_ENV = "POETRY_PYTHON"
RICH_SOURCES_ENV = "RICH_SOURCES"
class PackageTest(SparkTest):
@classmethod
def setUpClass(cls):
super(PackageTest, cls).setUpClass()
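        # .NET ticks count 100ns intervals since 0001-01-01T00:00:00;
        # 621355968000000000 ticks is the Unix epoch (1970-01-01).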
cls.ticks = cls.spark.createDataFrame([
(1, 599266080000000000),
(2, 621355968000000000),
(3, 638155413748959308),
(4, 638155413748959309),
(5, 638155413748959310),
(6, 713589688368547758),
(7, 946723967999999999)
], ['id', 'tick'])
cls.timestamps = cls.spark.createDataFrame([
(1, datetime.datetime(1900, 1, 1, tzinfo=datetime.timezone.utc).astimezone().replace(tzinfo=None)),
(2, datetime.datetime(1970, 1, 1, tzinfo=datetime.timezone.utc).astimezone().replace(tzinfo=None)),
(3, datetime.datetime(2023, 3, 27, 19, 16, 14, 895930, datetime.timezone.utc).astimezone().replace(tzinfo=None)),
(4, datetime.datetime(2023, 3, 27, 19, 16, 14, 895930, datetime.timezone.utc).astimezone().replace(tzinfo=None)),
(5, datetime.datetime(2023, 3, 27, 19, 16, 14, 895931, datetime.timezone.utc).astimezone().replace(tzinfo=None)),
(6, datetime.datetime(2262, 4, 11, 23, 47, 16, 854775, datetime.timezone.utc).astimezone().replace(tzinfo=None)),
(7, datetime.datetime(3001, 1, 19, 7, 59, 59, 999999, datetime.timezone.utc).astimezone().replace(tzinfo=None))
], ['id', 'timestamp'])
cls.unix = cls.spark.createDataFrame([
(1, Decimal('-2208988800.000000000')),
(2, Decimal('0E-9')),
(3, Decimal('1679944574.895930800')),
(4, Decimal('1679944574.895930900')),
(5, Decimal('1679944574.895931000')),
(6, Decimal('9223372036.854775800')),
(7, Decimal('32536799999.999999900'))
], ['id', 'unix'])
cls.unix_nanos = cls.spark.createDataFrame([
(1, -2208988800000000000),
(2, 0),
(3, 1679944574895930800),
(4, 1679944574895930900),
(5, 1679944574895931000),
(6, 9223372036854775800),
(7, None)
], ['id', 'unix_nanos'])
cls.ticks_from_timestamp = cls.spark.createDataFrame([
(1, 599266080000000000),
(2, 621355968000000000),
(3, 638155413748959300),
(4, 638155413748959300),
(5, 638155413748959310),
(6, 713589688368547750),
(7, 946723967999999990)
], ['id', 'tick'])
cls.ticks_from_unix_nanos = cls.spark.createDataFrame([
(1, 599266080000000000),
(2, 621355968000000000),
(3, 638155413748959308),
(4, 638155413748959309),
(5, 638155413748959310),
(6, 713589688368547758),
(7, None)
], ['id', 'tick'])
def compare_dfs(self, expected, actual):
print('expected')
expected.show(truncate=False)
print('actual')
actual.show(truncate=False)
self.assertEqual(
[row.asDict() for row in actual.collect()],
[row.asDict() for row in expected.collect()]
)
def test_backticks(self):
self.assertEqual(backticks("column"), "column")
self.assertEqual(backticks("a.column"), "`a.column`")
self.assertEqual(backticks("`a.column`"), "`a.column`")
self.assertEqual(backticks("column", "a.field"), "column.`a.field`")
self.assertEqual(backticks("a.column", "a.field"), "`a.column`.`a.field`")
self.assertEqual(backticks("the.alias", "a.column", "a.field"), "`the.alias`.`a.column`.`a.field`")
def test_distinct_prefix_for(self):
self.assertEqual(distinct_prefix_for([]), "_")
self.assertEqual(distinct_prefix_for(["a"]), "_")
self.assertEqual(distinct_prefix_for(["abc"]), "_")
self.assertEqual(distinct_prefix_for(["a", "bc", "def"]), "_")
self.assertEqual(distinct_prefix_for(["_a"]), "__")
self.assertEqual(distinct_prefix_for(["_abc"]), "__")
self.assertEqual(distinct_prefix_for(["a", "_bc", "__def"]), "___")
def test_handle_configured_case_sensitivity(self):
case_sensitive = False
with self.subTest(case_sensitive=case_sensitive):
self.assertEqual(handle_configured_case_sensitivity('abc', case_sensitive), 'abc')
self.assertEqual(handle_configured_case_sensitivity('AbC', case_sensitive), 'abc')
self.assertEqual(handle_configured_case_sensitivity('ABC', case_sensitive), 'abc')
case_sensitive = True
with self.subTest(case_sensitive=case_sensitive):
self.assertEqual(handle_configured_case_sensitivity('abc', case_sensitive), 'abc')
self.assertEqual(handle_configured_case_sensitivity('AbC', case_sensitive), 'AbC')
self.assertEqual(handle_configured_case_sensitivity('ABC', case_sensitive), 'ABC')
def test_list_contains_case_sensitivity(self):
the_list = ['abc', 'Def', 'GhI', 'JKL']
self.assertEqual(list_contains_case_sensitivity(the_list, 'a', case_sensitive=False), False)
self.assertEqual(list_contains_case_sensitivity(the_list, 'abc', case_sensitive=False), True)
self.assertEqual(list_contains_case_sensitivity(the_list, 'deF', case_sensitive=False), True)
self.assertEqual(list_contains_case_sensitivity(the_list, 'JKL', case_sensitive=False), True)
self.assertEqual(list_contains_case_sensitivity(the_list, 'a', case_sensitive=True), False)
self.assertEqual(list_contains_case_sensitivity(the_list, 'abc', case_sensitive=True), True)
self.assertEqual(list_contains_case_sensitivity(the_list, 'deF', case_sensitive=True), False)
self.assertEqual(list_contains_case_sensitivity(the_list, 'JKL', case_sensitive=True), True)
def test_list_filter_case_sensitivity(self):
the_list = ['abc', 'Def', 'GhI', 'JKL']
self.assertEqual(list_filter_case_sensitivity(the_list, ['a'], case_sensitive=False), [])
self.assertEqual(list_filter_case_sensitivity(the_list, ['abc'], case_sensitive=False), ['abc'])
self.assertEqual(list_filter_case_sensitivity(the_list, ['deF'], case_sensitive=False), ['Def'])
self.assertEqual(list_filter_case_sensitivity(the_list, ['JKL'], case_sensitive=False), ['JKL'])
self.assertEqual(list_filter_case_sensitivity(the_list, ['a'], case_sensitive=True), [])
self.assertEqual(list_filter_case_sensitivity(the_list, ['abc'], case_sensitive=True), ['abc'])
self.assertEqual(list_filter_case_sensitivity(the_list, ['deF'], case_sensitive=True), [])
self.assertEqual(list_filter_case_sensitivity(the_list, ['JKL'], case_sensitive=True), ['JKL'])
def test_list_diff_case_sensitivity(self):
the_list = ['abc', 'Def', 'GhI', 'JKL']
self.assertEqual(list_diff_case_sensitivity(the_list, ['a'], case_sensitive=False), the_list)
self.assertEqual(list_diff_case_sensitivity(the_list, ['abc'], case_sensitive=False), ['Def', 'GhI', 'JKL'])
self.assertEqual(list_diff_case_sensitivity(the_list, ['deF'], case_sensitive=False), ['abc', 'GhI', 'JKL'])
self.assertEqual(list_diff_case_sensitivity(the_list, ['JKL'], case_sensitive=False), ['abc', 'Def', 'GhI'])
self.assertEqual(list_diff_case_sensitivity(the_list, ['a'], case_sensitive=True), the_list)
self.assertEqual(list_diff_case_sensitivity(the_list, ['abc'], case_sensitive=True), ['Def', 'GhI', 'JKL'])
self.assertEqual(list_diff_case_sensitivity(the_list, ['deF'], case_sensitive=True), the_list)
self.assertEqual(list_diff_case_sensitivity(the_list, ['JKL'], case_sensitive=True), ['abc', 'Def', 'GhI'])
@skipIf(SparkTest.is_spark_connect, "Spark Connect does not provide access to the JVM, required by dotnet ticks")
def test_dotnet_ticks_to_timestamp(self):
for column in ["tick", self.ticks.tick]:
with self.subTest(column=column):
timestamps = self.ticks.withColumn("timestamp", dotnet_ticks_to_timestamp(column)).orderBy('id')
expected = self.ticks.join(self.timestamps, "id").orderBy('id')
self.compare_dfs(expected, timestamps)
@skipIf(SparkTest.is_spark_connect, "Spark Connect does not provide access to the JVM, required by dotnet ticks")
def test_dotnet_ticks_to_unix_epoch(self):
for column in ["tick", self.ticks.tick]:
with self.subTest(column=column):
timestamps = self.ticks.withColumn("unix", dotnet_ticks_to_unix_epoch(column)).orderBy('id')
expected = self.ticks.join(self.unix, "id").orderBy('id')
self.compare_dfs(expected, timestamps)
@skipIf(SparkTest.is_spark_connect, "Spark Connect does not provide access to the JVM, required by dotnet ticks")
def test_dotnet_ticks_to_unix_epoch_nanos(self):
self.maxDiff = None
for column in ["tick", self.ticks.tick]:
with self.subTest(column=column):
timestamps = self.ticks.withColumn("unix_nanos", dotnet_ticks_to_unix_epoch_nanos(column)).orderBy('id')
expected = self.ticks.join(self.unix_nanos, "id").orderBy('id')
self.compare_dfs(expected, timestamps)
@skipIf(SparkTest.is_spark_connect, "Spark Connect does not provide access to the JVM, required by dotnet ticks")
def test_timestamp_to_dotnet_ticks(self):
if self.spark.version.startswith('3.0.'):
self.skipTest('timestamp_to_dotnet_ticks not supported by Spark 3.0')
for column in ["timestamp", self.timestamps.timestamp]:
with self.subTest(column=column):
timestamps = self.timestamps.withColumn("tick", timestamp_to_dotnet_ticks(column)).orderBy('id')
expected = self.timestamps.join(self.ticks_from_timestamp, "id").orderBy('id')
self.compare_dfs(expected, timestamps)
@skipIf(SparkTest.is_spark_connect, "Spark Connect does not provide access to the JVM, required by dotnet ticks")
def test_unix_epoch_dotnet_ticks(self):
for column in ["unix", self.unix.unix]:
with self.subTest(column=column):
timestamps = self.unix.withColumn("tick", unix_epoch_to_dotnet_ticks(column)).orderBy('id')
expected = self.unix.join(self.ticks, "id").orderBy('id')
self.compare_dfs(expected, timestamps)
@skipIf(SparkTest.is_spark_connect, "Spark Connect does not provide access to the JVM, required by dotnet ticks")
def test_unix_epoch_nanos_to_dotnet_ticks(self):
for column in ["unix_nanos", self.unix_nanos.unix_nanos]:
with self.subTest(column=column):
timestamps = self.unix_nanos.withColumn("tick", unix_epoch_nanos_to_dotnet_ticks(column)).orderBy('id')
expected = self.unix_nanos.join(self.ticks_from_unix_nanos, "id").orderBy('id')
self.compare_dfs(expected, timestamps)
def test_count_null(self):
actual = self.unix_nanos.select(
count("id").alias("ids"),
count(col("unix_nanos")).alias("nanos"),
count_null("id").alias("null_ids"),
count_null(col("unix_nanos")).alias("null_nanos"),
).collect()
self.assertEqual([Row(ids=7, nanos=6, null_ids=0, null_nanos=1)], actual)
def test_session(self):
self.assertIsNotNone(self.ticks.session())
self.assertIsInstance(self.ticks.session(), tuple(([SparkSession] + ([ConnectSparkSession] if has_connect else []))))
def test_session_or_ctx(self):
self.assertIsNotNone(self.ticks.session_or_ctx())
self.assertIsInstance(self.ticks.session_or_ctx(), tuple(([SparkSession, SQLContext] + ([ConnectSparkSession] if has_connect else []))))
@skipIf(SparkTest.is_spark_connect, "Spark Connect does not provide access to the JVM, required by create_temp_dir")
def test_create_temp_dir(self):
from pyspark import SparkFiles
dir = self.spark.create_temporary_dir("prefix")
self.assertTrue(dir.startswith(SparkFiles.getRootDirectory()))
@skipIf(__version__.startswith('3.0.'), 'install_pip_package not supported for Spark 3.0')
@skipIf(SparkTest.is_spark_connect, "Spark Connect does not provide access to the JVM, required by install_pip_package")
def test_install_pip_package(self):
self.spark.sparkContext.setLogLevel("INFO")
with self.assertRaises(ImportError):
# noinspection PyPackageRequirements
import emoji
emoji.emojize("this test is :thumbs_up:")
self.spark.install_pip_package("emoji", '--cache', '.cache/pypi')
# noinspection PyPackageRequirements
import emoji
actual = emoji.emojize("this test is :thumbs_up:")
expected = "this test is 👍"
self.assertEqual(expected, actual)
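# run the installed package on executors via mapInPandas to check it is available cluster-wide, not only on the driver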
import pandas as pd
actual = self.spark.range(0, 10, 1, 10) \
.mapInPandas(lambda it: [pd.DataFrame.from_dict({"val": [emoji.emojize(":thumbs_up:")]})], "val string") \
.collect()
expected = [Row("👍")] * 10
self.assertEqual(expected, actual)
@skipIf(__version__.startswith('3.0.'), 'install_pip_package not supported for Spark 3.0')
@skipIf(SparkTest.is_spark_connect, "Spark Connect does not provide access to the JVM, required by install_pip_package")
def test_install_pip_package_unknown_argument(self):
with self.assertRaises(CalledProcessError):
self.spark.install_pip_package("--unknown", "argument")
@skipIf(__version__.startswith('3.0.'), 'install_pip_package not supported for Spark 3.0')
@skipIf(SparkTest.is_spark_connect, "Spark Connect does not provide access to the JVM, required by install_pip_package")
def test_install_pip_package_package_not_found(self):
with self.assertRaises(CalledProcessError):
self.spark.install_pip_package("pyspark-extension==abc")
@skipUnless(__version__.startswith('3.0.'), 'install_pip_package not supported for Spark 3.0')
@skipIf(SparkTest.is_spark_connect, "Spark Connect does not provide access to the JVM, required by install_pip_package")
def test_install_pip_package_not_supported(self):
with self.assertRaises(NotImplementedError):
self.spark.install_pip_package("emoji")
@skipIf(__version__.startswith('3.0.'), 'install_poetry_project not supported for Spark 3.0')
# provide an environment variable with path to the python binary of a virtual env that has poetry installed
@skipIf(POETRY_PYTHON_ENV not in os.environ, f'Environment variable {POETRY_PYTHON_ENV} pointing to '
f'virtual env python with poetry required')
@skipIf(RICH_SOURCES_ENV not in os.environ, f'Environment variable {RICH_SOURCES_ENV} pointing to '
f'rich project sources required')
@skipIf(SparkTest.is_spark_connect, "Spark Connect does not provide access to the JVM, required by install_poetry_project")
def test_install_poetry_project(self):
self.spark.sparkContext.setLogLevel("INFO")
with self.assertRaises(ImportError):
# noinspection PyPackageRequirements
from rich.emoji import Emoji
thumbs_up = Emoji("thumbs_up")
rich_path = os.environ[RICH_SOURCES_ENV]
poetry_python = os.environ[POETRY_PYTHON_ENV]
self.spark.install_poetry_project(
rich_path,
poetry_python=poetry_python,
pip_args=['--cache', '.cache/pypi']
)
# noinspection PyPackageRequirements
from rich.emoji import Emoji
thumbs_up = Emoji("thumbs_up")
actual = thumbs_up.replace("this test is :thumbs_up:")
expected = "this test is 👍"
self.assertEqual(expected, actual)
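# likewise, use the installed project on executors via mapInPandas to verify the cluster-wide installation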
import pandas as pd
actual = self.spark.range(0, 10, 1, 10) \
.mapInPandas(lambda it: [pd.DataFrame.from_dict({"val": [thumbs_up.replace(":thumbs_up:")]})], "val string") \
.collect()
expected = [Row("👍")] * 10
self.assertEqual(expected, actual)
@skipIf(__version__.startswith('3.0.'), 'install_poetry_project not supported for Spark 3.0')
# provide an environment variable with path to the python binary of a virtual env that has poetry installed
@skipIf(POETRY_PYTHON_ENV not in os.environ, f'Environment variable {POETRY_PYTHON_ENV} pointing to '
f'virtual env python with poetry required')
@skipIf(RICH_SOURCES_ENV not in os.environ, f'Environment variable {RICH_SOURCES_ENV} pointing to '
f'rich project sources required')
@skipIf(SparkTest.is_spark_connect, "Spark Connect does not provide access to the JVM, required by install_poetry_project")
def test_install_poetry_project_wrong_arguments(self):
rich_path = os.environ[RICH_SOURCES_ENV]
poetry_python = os.environ[POETRY_PYTHON_ENV]
with self.assertRaises(RuntimeError):
self.spark.install_poetry_project("non-existing-project", poetry_python=poetry_python)
with self.assertRaises(FileNotFoundError):
self.spark.install_poetry_project(rich_path, poetry_python="non-existing-python")
@skipUnless(__version__.startswith('3.0.'), 'install_poetry_project not supported for Spark 3.0')
@skipIf(SparkTest.is_spark_connect, "Spark Connect does not provide access to the JVM, required by install_poetry_project")
def test_install_poetry_project_not_supported(self):
with self.assertRaises(NotImplementedError):
self.spark.install_poetry_project("./rich")
if __name__ == '__main__':
SparkTest.main(__file__)
================================================
FILE: python/test/test_parquet.py
================================================
# Copyright 2023 G-Research
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from pathlib import Path
from unittest import skipIf
from spark_common import SparkTest
import gresearch.spark.parquet
@skipIf(SparkTest.is_spark_connect, "Spark Connect does not provide access to the JVM, required by Parquet")
class ParquetTest(SparkTest):
test_file = str((Path(__file__).parent.parent.parent / "src" / "test" / "files" / "test.parquet").resolve())
def test_parquet_metadata(self):
self.assertEqual(self.spark.read.parquet_metadata(self.test_file).count(), 2)
self.assertEqual(self.spark.read.parquet_metadata(self.test_file, self.test_file).count(), 2)
self.assertEqual(self.spark.read.parquet_metadata(self.test_file, parallelism=100).count(), 2)
self.assertEqual(self.spark.read.parquet_metadata(self.test_file, self.test_file, parallelism=100).count(), 2)
def test_parquet_schema(self):
self.assertEqual(self.spark.read.parquet_schema(self.test_file).count(), 4)
self.assertEqual(self.spark.read.parquet_schema(self.test_file, self.test_file).count(), 4)
self.assertEqual(self.spark.read.parquet_schema(self.test_file, parallelism=100).count(), 4)
self.assertEqual(self.spark.read.parquet_schema(self.test_file, self.test_file, parallelism=100).count(), 4)
def test_parquet_blocks(self):
self.assertEqual(self.spark.read.parquet_blocks(self.test_file).count(), 3)
self.assertEqual(self.spark.read.parquet_blocks(self.test_file, self.test_file).count(), 3)
self.assertEqual(self.spark.read.parquet_blocks(self.test_file, parallelism=100).count(), 3)
self.assertEqual(self.spark.read.parquet_blocks(self.test_file, self.test_file, parallelism=100).count(), 3)
def test_parquet_block_columns(self):
self.assertEqual(self.spark.read.parquet_block_columns(self.test_file).count(), 6)
self.assertEqual(self.spark.read.parquet_block_columns(self.test_file, self.test_file).count(), 6)
self.assertEqual(self.spark.read.parquet_block_columns(self.test_file, parallelism=100).count(), 6)
self.assertEqual(self.spark.read.parquet_block_columns(self.test_file, self.test_file, parallelism=100).count(), 6)
def test_parquet_partitions(self):
self.assertEqual(self.spark.read.parquet_partitions(self.test_file).count(), 2)
self.assertEqual(self.spark.read.parquet_partitions(self.test_file, self.test_file).count(), 2)
self.assertEqual(self.spark.read.parquet_partitions(self.test_file, parallelism=100).count(), 2)
self.assertEqual(self.spark.read.parquet_partitions(self.test_file, self.test_file, parallelism=100).count(), 2)
if __name__ == '__main__':
SparkTest.main(__file__)
================================================
FILE: python/test/test_row_number.py
================================================
# Copyright 2022 G-Research
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from unittest import skipIf
from pyspark.storagelevel import StorageLevel
from spark_common import SparkTest
import gresearch.spark
@skipIf(SparkTest.is_spark_connect, "Spark Connect does not provide access to the JVM, required by RowNumber")
class RowNumberTest(SparkTest):
@classmethod
def setUpClass(cls):
super(RowNumberTest, cls).setUpClass()
cls.df1 = cls.spark.createDataFrame([
(1, 'one'),
(2, 'two'),
(3, 'three'),
(4, 'four'),
], ['id', 'value'])
cls.expected1 = [
{'id': 1, 'value': 'one', 'row_number': 1},
{'id': 2, 'value': 'two', 'row_number': 2},
{'id': 3, 'value': 'three', 'row_number': 3},
{'id': 4, 'value': 'four', 'row_number': 4},
]
cls.expected1Desc = [
{'id': 1, 'value': 'one', 'row_number': 4},
{'id': 2, 'value': 'two', 'row_number': 3},
{'id': 3, 'value': 'three', 'row_number': 2},
{'id': 4, 'value': 'four', 'row_number': 1},
]
cls.df2 = cls.spark.createDataFrame([
(1, 'one'),
(2, 'TWO'),
(2, 'two'),
(3, 'three'),
], ['id', 'value'])
cls.expected2 = [
{'id': 1, 'value': 'one', 'row_number': 1},
{'id': 2, 'value': 'TWO', 'row_number': 2},
{'id': 2, 'value': 'two', 'row_number': 3},
{'id': 3, 'value': 'three', 'row_number': 4},
]
cls.expected2Desc = [
{'id': 1, 'value': 'one', 'row_number': 4},
{'id': 2, 'value': 'TWO', 'row_number': 3},
{'id': 2, 'value': 'two', 'row_number': 2},
{'id': 3, 'value': 'three', 'row_number': 1},
]
def test_row_numbers(self):
rows = self.df1.with_row_numbers().orderBy('id', 'value').collect()
self.assertEqual(self.expected1, [row.asDict() for row in rows])
def test_row_numbers_order_one_column(self):
for order in ['id', ['id'], self.df1.id, [self.df1.id]]:
with self.subTest(order=order):
rows = self.df1.with_row_numbers(order=order).orderBy('id', 'value').collect()
self.assertEqual(self.expected1, [row.asDict() for row in rows])
def test_row_numbers_order_two_columns(self):
for order in [['id', 'value'], [self.df2.id, self.df2.value]]:
with self.subTest(order=order):
rows = self.df2.with_row_numbers(order=order).orderBy('id', 'value').collect()
self.assertEqual(self.expected2, [row.asDict() for row in rows])
def test_row_numbers_order_not_asc_one_column(self):
for order in ['id', ['id'], self.df1.id, [self.df1.id]]:
with self.subTest(order=order):
rows = self.df1.with_row_numbers(order=order, ascending=False).orderBy('id', 'value').collect()
self.assertEqual(self.expected1Desc, [row.asDict() for row in rows])
def test_row_numbers_order_not_asc_two_columns(self):
for order in [['id', 'value'], [self.df2.id, self.df2.value]]:
with self.subTest(order=order):
rows = self.df2.with_row_numbers(order=order, ascending=False).orderBy('id', 'value').collect()
self.assertEqual(self.expected2Desc, [row.asDict() for row in rows])
def test_row_numbers_order_desc_one_column(self):
for order in [self.df1.id.desc(), [self.df1.id.desc()]]:
with self.subTest(order=order):
rows = self.df1.with_row_numbers(order=order).orderBy('id', 'value').collect()
self.assertEqual(self.expected1Desc, [row.asDict() for row in rows])
def test_row_numbers_order_desc_two_columns(self):
for order in [[self.df2.id.desc(), self.df2.value.desc()]]:
with self.subTest(order=order):
rows = self.df2.with_row_numbers(order=order).orderBy('id', 'value').collect()
self.assertEqual(self.expected2Desc, [row.asDict() for row in rows])
def test_row_numbers_unpersist(self):
for storage_level in [StorageLevel.MEMORY_AND_DISK, StorageLevel.MEMORY_ONLY, StorageLevel.DISK_ONLY]:
with self.subTest(storage_level=storage_level):
# make sure the cache is clear
jcm = self.spark._jsparkSession.sharedState().cacheManager()
jcm.clearCache()
self.assertTrue(jcm.isEmpty())
unpersist = self.spark.unpersist_handle()
self.df1.with_row_numbers(storage_level=storage_level, unpersist_handle=unpersist) \
.orderBy('id', 'value').collect()
# the cache should not be empty now
self.assertFalse(jcm.isEmpty())
unpersist(blocking=True)
# this should have removed the only DataFrame from the cache
self.assertTrue(jcm.isEmpty())
# calling unpersist again does not hurt, this time without blocking
unpersist()
def test_row_numbers_row_number_col_name(self):
rows = self.df1.with_row_numbers(row_number_column_name='row').orderBy('id', 'value').collect()
self.assertEqual([{'row' if k == 'row_number' else k: v for k, v in row.items()}
for row in self.expected1],
[row.asDict() for row in rows])
if __name__ == '__main__':
SparkTest.main(__file__)
================================================
FILE: release.sh
================================================
#!/bin/bash
#
# Copyright 2020 G-Research
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# Script to prepare release, see RELEASE.md for details
set -euo pipefail
# check for clean git status (except for CHANGELOG.md and release.sh)
readarray -t git_status < <(git status -s --untracked-files=no 2>/dev/null | grep -v -e " CHANGELOG.md$" -e " release.sh$")
if [ ${#git_status[@]} -gt 0 ]
then
echo "There are pending git changes:"
for (( i=0; i<${#git_status[@]}; i++ )); do echo "${git_status[$i]}" ; done
exit 1
fi
# check for unreleased entry in CHANGELOG.md
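# (take the lines between the '## [UNRELEASED] - YYYY-MM-DD' header and the next '## [x.y.z]' header, keep only the '- ' bullet lines)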
readarray -t changes < <(grep -A 100 "^## \[UNRELEASED\] - YYYY-MM-DD" CHANGELOG.md | grep -B 100 --max-count=1 -E "^## \[[0-9.]+\]" | grep "^-")
if [ ${#changes[@]} -eq 0 ]
then
echo "Did not find any changes in CHANGELOG.md under '## [UNRELEASED] - YYYY-MM-DD'"
exit 1
fi
# check this is a SNAPSHOT version
if ! grep -q ".*-SNAPSHOT" pom.xml
then
echo "Version in pom is not a SNAPSHOT version, cannot test all versions"
exit 1
fi
# check for existing cached SNAPSHOT jars
version=$(grep --max-count=1 "<version>.*</version>" pom.xml | sed -E -e "s/\s*<[^>]+>//g" -e "s/-SNAPSHOT//" -e "s/-[0-9.]+//g")
jars=$(find $HOME/.m2 $HOME/.ivy2 -name "*spark-extension_*-$version-*-SNAPSHOT.jar")
if [[ -n "$jars" ]]
then
echo "There are installed SNAPSHOT jars, these may interfere with release tests. These must be deleted first:"
echo "$jars" | tr '\n' ' '
echo
exit 1
fi
# testing all versions
rm -rf metastore_db/ spark-warehouse/
./set-version.sh 3.2.4 2.12.15; mvn clean deploy -Dsign; ./build-whl.sh; ./test-release.sh
./set-version.sh 3.3.4 2.12.15; mvn clean deploy -Dsign; ./build-whl.sh; ./test-release.sh
./set-version.sh 3.4.4 2.12.17; mvn clean deploy -Dsign; ./build-whl.sh; ./test-release.sh
./set-version.sh 3.5.3 2.12.18; mvn clean deploy -Dsign; ./build-whl.sh; ./test-release.sh
rm -rf python/dist
./set-version.sh 3.2.4 2.13.5; mvn clean deploy -Dsign; ./test-release.sh
./set-version.sh 3.3.4 2.13.8; mvn clean deploy -Dsign; ./test-release.sh
./set-version.sh 3.4.4 2.13.8; mvn clean deploy -Dsign; ./test-release.sh
./set-version.sh 3.5.3 2.13.8; mvn clean deploy -Dsign; ./test-release.sh
rm -rf metastore_db/ spark-warehouse/
# all SNAPSHOT versions build, test and complete the example, so proceed with the release
# revert pom.xml and python/setup.py changes
git checkout pom.xml python/setup.py
# get latest and release version
latest=$(grep --max-count=1 "<version>.*</version>" README.md | sed -E -e "s/\s*<[^>]+>//g" -e "s/-[0-9.]+//g")
version=$(grep --max-count=1 "<version>.*</version>" pom.xml | sed -E -e "s/\s*<[^>]+>//g" -e "s/-SNAPSHOT//" -e "s/-[0-9.]+//g")
echo "Releasing ${#changes[@]} changes as version $version:"
for (( i=0; i<${#changes[@]}; i++ )); do echo "${changes[$i]}" ; done
sed -i "s/## \[UNRELEASED\] - YYYY-MM-DD/## [$version] - $(date +%Y-%m-%d)/" CHANGELOG.md
sed -i -e "s/$latest-/$version-/g" -e "s/$latest\./$version./g" -e "s/v$latest/v$version/g" README.md PYSPARK-DEPS.md python/README.md
./set-version.sh $version
# commit changes to local repo
echo
echo "Committing release to local git"
git add pom.xml python/setup.py CHANGELOG.md README.md PYSPARK-DEPS.md python/README.md
git commit -m "Releasing $version"
git tag -a "v${version}" -m "Release v${version}"
echo "Please inspect git changes:"
git show HEAD
echo "Press to push to origin"
read
echo "Pushing release commit and tag to origin"
git push origin master "v${version}"
echo
# create release
echo "Creating release packages"
mkdir -p python/pyspark/jars/
./set-version.sh 3.2.4 2.12.15; mvn clean deploy -Dsign -Dspotless.check.skip -DskipTests -Dmaven.test.skip=true; ./build-whl.sh
./set-version.sh 3.3.4 2.12.15; mvn clean deploy -Dsign -Dspotless.check.skip -DskipTests -Dmaven.test.skip=true; ./build-whl.sh
./set-version.sh 3.4.4 2.12.17; mvn clean deploy -Dsign -Dspotless.check.skip -DskipTests -Dmaven.test.skip=true; ./build-whl.sh
./set-version.sh 3.5.3 2.12.18; mvn clean deploy -Dsign -Dspotless.check.skip -DskipTests -Dmaven.test.skip=true; ./build-whl.sh
./set-version.sh 3.2.4 2.13.5; mvn clean deploy -Dsign -Dspotless.check.skip -DskipTests -Dmaven.test.skip=true
./set-version.sh 3.3.4 2.13.8; mvn clean deploy -Dsign -Dspotless.check.skip -DskipTests -Dmaven.test.skip=true
./set-version.sh 3.4.4 2.13.8; mvn clean deploy -Dsign -Dspotless.check.skip -DskipTests -Dmaven.test.skip=true
./set-version.sh 3.5.3 2.13.8; mvn clean deploy -Dsign -Dspotless.check.skip -DskipTests -Dmaven.test.skip=true
# upload to test PyPi
pip install twine
twine check python/dist/*
python3 -m twine upload --repository testpypi python/dist/*
echo "Press to upload to PyPi"
read
# upload to PyPi
python3 -m twine upload python/dist/*
echo
git checkout pom.xml python/setup.py
./bump-version.sh
================================================
FILE: set-version.sh
================================================
#!/bin/bash
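#
# Set the Spark-Extension version, or the Spark and Scala versions, in pom.xml and python/setup.py.
# Usage (versions are examples):
# ./set-version.sh 2.5.0-SNAPSHOT # set the Spark-Extension version
# ./set-version.sh 3.5.3 2.13.8 # set the Spark and Scala versions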
if [ $# -eq 1 ]
then
IFS=- read version flavour <<< "$1"
echo "setting version=$version${flavour:+ with }$flavour"
sed -i -E \
-e "s%^( )[^-]+-([^-]+).*()$%\1$version-\2${flavour:+-}$flavour\3%" \
pom.xml
version=$(grep -m 1 version pom.xml | sed "s/\s*<[^>]*>\s*//g")
sed -i -E \
-e "s/(jar_version *= *).*/\1'$version'/" \
python/setup.py
elif [ $# -eq 2 ]
then
spark=$1
scala=$2
spark_compat=${spark%.*}
scala_compat=${scala%.*}
spark_major=${spark_compat%.*}
scala_major=${scala_compat%.*}
spark_minor=${spark_compat/*./}
scala_minor=${scala_compat/*./}
spark_patch=${spark/*./}
scala_patch=${scala/*./}
echo "setting spark=$spark and scala=$scala"
sed -i -E \
-e "s%^( )([^_]+)[_0-9.]+()$%\1\2_${scala_compat}\3%" \
-e "s%^( )([^-]+)-[^-]+(.*)$%\1\2-$spark_compat\3%" \
-e "s%^( ).+()$%\1${scala_major}\2%" \
-e "s%^( ).+()$%\1${scala_minor}\2%" \
-e "s%^( ).+()$%\1${scala_patch}\2%" \
-e "s%^( ).+()$%\1${spark_major}\2%" \
-e "s%^( ).+()$%\1${spark_minor}\2%" \
-e "s%^( ).+()$%\1${spark_patch}\2%" \
pom.xml
version=$(grep -m 1 version pom.xml | sed "s/\s*<[^>]*>\s*//g")
sed -i -E \
-e "s/(jar_version *= *).*/\1'$version'/" \
-e "s/(scala_version *= *).*/\1'$scala'/" \
python/setup.py
else
echo "Provide the Spark-Extension version (e.g. 2.5.0 or 2.5.0-SNAPSHOT), or the Spark and Scala version"
exit 1
fi
================================================
FILE: src/main/scala/uk/co/gresearch/package.scala
================================================
/*
* Copyright 2020 G-Research
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package uk.co
package object gresearch {
trait ConditionalCall[T] {
def call(f: T => T): T
def either[R](f: T => R): ConditionalCallOr[T, R]
}
trait ConditionalCallOr[T, R] {
def or(f: T => R): R
}
case class TrueCall[T](t: T) extends ConditionalCall[T] {
override def call(f: T => T): T = f(t)
override def either[R](f: T => R): ConditionalCallOr[T, R] = TrueCallOr[T, R](f(t))
}
case class FalseCall[T](t: T) extends ConditionalCall[T] {
override def call(f: T => T): T = t
override def either[R](f: T => R): ConditionalCallOr[T, R] = FalseCallOr[T, R](t)
}
case class TrueCallOr[T, R](r: R) extends ConditionalCallOr[T, R] {
override def or(f: T => R): R = r
}
case class FalseCallOr[T, R](t: T) extends ConditionalCallOr[T, R] {
override def or(f: T => R): R = f(t)
}
implicit class ExtendedAny[T](t: T) {
/**
* Allows calling a function on the decorated instance conditionally.
*
* This allows fluent code like
*
* {{{
* i.doThis()
* .doThat()
* .on(condition).call(function)
* .on(condition).either(function1).or(function2)
* .doMore()
* }}}
*
* rather than
*
* {{{
* val temp = i.doThis()
* .doThat()
* val temp2 = if (condition) function(temp) else temp
* temp2.doMore()
* }}}
*
* which either needs many temporary variables or duplicate code.
*
* @param condition
* condition
* @return
* the function result
*/
def on(condition: Boolean): ConditionalCall[T] = {
if (condition) TrueCall[T](t) else FalseCall[T](t)
}
/**
* Allows calling a function on the decorated instance conditionally. This is an alias for the `on` function.
*
* This allows fluent code like
*
* {{{
* i.doThis()
* .doThat()
* .when(condition).call(function)
* .when(condition).either(function1).or(function2)
* .doMore()
* }}}
*
* rather than
*
* {{{
* val temp = i.doThis()
* .doThat()
* val temp2 = if (condition) function(temp) else temp
* temp2.doMore()
* }}}
*
* which either needs many temporary variables or duplicate code.
*
* @param condition
* condition
* @return
* the function result
*/
def when(condition: Boolean): ConditionalCall[T] = on(condition)
/**
* Executes the given function on the decorated instance.
*
* This allows writing fluent code like
*
* {{{
* i.doThis()
* .doThat()
* .call(function)
* .doMore()
* }}}
*
* rather than
*
* {{{
* function(
* i.doThis()
* .doThat()
* ).doMore()
* }}}
*
* where the effective sequence of operations is not clear.
*
* @param f
* function
* @return
* the function result
*/
def call[R](f: T => R): R = f(t)
}
}
================================================
FILE: src/main/scala/uk/co/gresearch/spark/BuildVersion.scala
================================================
/*
* Copyright 2022 G-Research
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package uk.co.gresearch.spark
import java.util.Properties
/**
* Provides versions from the build environment.
*/
trait BuildVersion {
val propertyFileName = "spark-extension-build.properties"
lazy val props: Properties = {
val properties = new Properties
val in = Option(Thread.currentThread().getContextClassLoader.getResourceAsStream(propertyFileName))
if (in.isEmpty) {
throw new RuntimeException(s"Property file $propertyFileName not found in class path")
}
in.foreach(properties.load)
properties
}
lazy val VersionString: String = props.getProperty("project.version")
lazy val BuildSparkMajorVersion: Int = props.getProperty("spark.major.version").toInt
lazy val BuildSparkMinorVersion: Int = props.getProperty("spark.minor.version").toInt
lazy val BuildSparkPatchVersion: Int = props.getProperty("spark.patch.version").split("-").head.toInt
lazy val BuildSparkCompatVersionString: String = props.getProperty("spark.compat.version")
lazy val BuildScalaMajorVersion: Int = props.getProperty("scala.major.version").toInt
lazy val BuildScalaMinorVersion: Int = props.getProperty("scala.minor.version").toInt
lazy val BuildScalaPatchVersion: Int = props.getProperty("scala.patch.version").toInt
lazy val BuildScalaCompatVersionString: String = props.getProperty("scala.compat.version")
val BuildSparkVersion: (Int, Int, Int) = (BuildSparkMajorVersion, BuildSparkMinorVersion, BuildSparkPatchVersion)
val BuildSparkCompatVersion: (Int, Int) = (BuildSparkMajorVersion, BuildSparkMinorVersion)
val BuildScalaVersion: (Int, Int, Int) = (BuildScalaMajorVersion, BuildScalaMinorVersion, BuildScalaPatchVersion)
val BuildScalaCompatVersion: (Int, Int) = (BuildScalaMajorVersion, BuildScalaMinorVersion)
}
================================================
FILE: src/main/scala/uk/co/gresearch/spark/Histogram.scala
================================================
/*
* Copyright 2020 G-Research
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package uk.co.gresearch.spark
import org.apache.spark.sql.functions.{sum, when}
import org.apache.spark.sql.{Column, DataFrame, Dataset}
import uk.co.gresearch.ExtendedAny
import scala.collection.JavaConverters
object Histogram {
/**
* Compute the histogram of a column when aggregated by aggregate columns. Thresholds are expected to be provided in
* ascending order. The result dataframe contains the aggregate and histogram columns only. For each threshold value
* in thresholds, there will be a column named s"≤threshold". There will also be a final column called
* s">last_threshold", that counts the remaining values that exceed the last threshold.
*
* @param df
* dataset to compute histogram from
* @param thresholds
* sequence of thresholds in ascending order, must implement <= and > operators w.r.t. valueColumn
* @param valueColumn
* histogram is computed for values of this column
* @param aggregateColumns
* histogram is computed against these columns
* @tparam T
* type of histogram thresholds
* @return
* dataframe with aggregate and histogram columns
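*
* For example (a sketch, column names are illustrative),
* {{{
* Histogram.of(df, Seq(100, 200), col("score"), col("user"))
* }}}
* yields one row per `user` with columns `≤100`, `≤200` and `>200`, counting the scores
* in (-∞,100], (100,200] and (200,∞), respectively.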
*/
def of[D, T](df: Dataset[D], thresholds: Seq[T], valueColumn: Column, aggregateColumns: Column*): DataFrame = {
if (thresholds.isEmpty)
throw new IllegalArgumentException("Thresholds must not be empty")
val bins = if (thresholds.length == 1) Seq.empty else thresholds.sliding(2).toSeq
if (bins.exists(s => s.head == s.last))
throw new IllegalArgumentException(s"Thresholds must not contain duplicates: ${thresholds.mkString(",")}")
df.toDF()
.withColumn(s"≤${thresholds.head}", when(valueColumn <= thresholds.head, 1).otherwise(0))
.call(bins.foldLeft(_) { case (df, bin) =>
df.withColumn(s"≤${bin.last}", when(valueColumn > bin.head && valueColumn <= bin.last, 1).otherwise(0))
})
.withColumn(s">${thresholds.last}", when(valueColumn > thresholds.last, 1).otherwise(0))
.groupBy(aggregateColumns: _*)
.agg(
Some(thresholds.head).map(t => sum(backticks(s"≤$t")).as(s"≤$t")).get,
thresholds.tail.map(t => sum(backticks(s"≤$t")).as(s"≤$t")) :+
sum(backticks(s">${thresholds.last}")).as(s">${thresholds.last}"): _*
)
}
/**
* Compute the histogram of a column when aggregated by aggregate columns. Thresholds are expected to be provided in
* ascending order. The result dataframe contains the aggregate and histogram columns only. For each threshold value
* in thresholds, there will be a column named s"≤threshold". There will also be a final column called
* s">last_threshold", that counts the remaining values that exceed the last threshold.
*
* @param df
* dataset to compute histogram from
* @param thresholds
* sequence of thresholds in ascending order, must implement <= and > operators w.r.t. valueColumn
* @param valueColumn
* histogram is computed for values of this column
* @param aggregateColumns
* histogram is computed against these columns
* @tparam T
* type of histogram thresholds
* @return
* dataframe with aggregate and histogram columns
*/
@scala.annotation.varargs
def of[D, T](
df: Dataset[D],
thresholds: java.util.List[T],
valueColumn: Column,
aggregateColumns: Column*
): DataFrame =
of(df, JavaConverters.iterableAsScalaIterable(thresholds).toSeq, valueColumn, aggregateColumns: _*)
}
================================================
FILE: src/main/scala/uk/co/gresearch/spark/RowNumbers.scala
================================================
/*
* Copyright 2023 G-Research
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package uk.co.gresearch.spark
import org.apache.spark.sql.expressions.Window
import org.apache.spark.sql.{Column, DataFrame, Dataset, functions}
import org.apache.spark.sql.functions.{coalesce, col, lit, max, monotonically_increasing_id, spark_partition_id, sum}
import org.apache.spark.storage.StorageLevel
case class RowNumbersFunc(
rowNumberColumnName: String = "row_number",
storageLevel: StorageLevel = StorageLevel.MEMORY_AND_DISK,
unpersistHandle: UnpersistHandle = UnpersistHandle.Noop,
orderColumns: Seq[Column] = Seq.empty
) {
def withRowNumberColumnName(rowNumberColumnName: String): RowNumbersFunc =
this.copy(rowNumberColumnName = rowNumberColumnName)
def withStorageLevel(storageLevel: StorageLevel): RowNumbersFunc =
this.copy(storageLevel = storageLevel)
def withUnpersistHandle(unpersistHandle: UnpersistHandle): RowNumbersFunc =
this.copy(unpersistHandle = unpersistHandle)
def withOrderColumns(orderColumns: Seq[Column]): RowNumbersFunc =
this.copy(orderColumns = orderColumns)
def of[D](df: Dataset[D]): DataFrame = {
if (
storageLevel.equals(
StorageLevel.NONE
) && (SparkMajorVersion > 3 || SparkMajorVersion == 3 && SparkMinorVersion >= 5)
) {
throw new IllegalArgumentException(s"Storage level $storageLevel not supported with Spark 3.5.0 and above.")
}
// define some column names that do not exist in df
val prefix = distinctPrefixFor(df.columns)
val monoIdColumnName = prefix + "mono_id"
val partitionIdColumnName = prefix + "partition_id"
val localRowNumberColumnName = prefix + "local_row_number"
val maxLocalRowNumberColumnName = prefix + "max_local_row_number"
val cumRowNumbersColumnName = prefix + "cum_row_numbers"
val partitionOffsetColumnName = prefix + "partition_offset"
// if no order is given, we preserve existing order
val dfOrdered =
if (orderColumns.isEmpty) df.withColumn(monoIdColumnName, monotonically_increasing_id())
else df.orderBy(orderColumns: _*)
val order = if (orderColumns.isEmpty) Seq(col(monoIdColumnName)) else orderColumns
// add partition ids and local row numbers
val localRowNumberWindow = Window.partitionBy(partitionIdColumnName).orderBy(order: _*)
val dfWithPartitionId = dfOrdered
.withColumn(partitionIdColumnName, spark_partition_id())
.persist(storageLevel)
unpersistHandle.setDataFrame(dfWithPartitionId)
val dfWithLocalRowNumbers = dfWithPartitionId
.withColumn(localRowNumberColumnName, functions.row_number().over(localRowNumberWindow))
// compute row offset for the partitions
val cumRowNumbersWindow = Window
.orderBy(partitionIdColumnName)
.rowsBetween(Window.unboundedPreceding, Window.currentRow)
val partitionOffsets = dfWithLocalRowNumbers
.groupBy(partitionIdColumnName)
.agg(max(localRowNumberColumnName).alias(maxLocalRowNumberColumnName))
.withColumn(cumRowNumbersColumnName, sum(maxLocalRowNumberColumnName).over(cumRowNumbersWindow))
.select(
col(partitionIdColumnName) + 1 as partitionIdColumnName,
col(cumRowNumbersColumnName).as(partitionOffsetColumnName)
)
// compute global row number by adding local row number with partition offset
val partitionOffsetColumn = coalesce(col(partitionOffsetColumnName), lit(0))
dfWithLocalRowNumbers
.join(partitionOffsets, Seq(partitionIdColumnName), "left")
.withColumn(rowNumberColumnName, col(localRowNumberColumnName) + partitionOffsetColumn)
.drop(monoIdColumnName, partitionIdColumnName, localRowNumberColumnName, partitionOffsetColumnName)
}
}
object RowNumbers {
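// Usage sketch (`ds` is an illustrative Dataset):
// RowNumbers.withOrderColumns(col("id")).of(ds)
// returns ds with an additional "row_number" column, ordered by column "id".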
def default(): RowNumbersFunc = RowNumbersFunc()
def withRowNumberColumnName(rowNumberColumnName: String): RowNumbersFunc =
default().withRowNumberColumnName(rowNumberColumnName)
def withStorageLevel(storageLevel: StorageLevel): RowNumbersFunc =
default().withStorageLevel(storageLevel)
def withUnpersistHandle(unpersistHandle: UnpersistHandle): RowNumbersFunc =
default().withUnpersistHandle(unpersistHandle)
@scala.annotation.varargs
def withOrderColumns(orderColumns: Column*): RowNumbersFunc =
default().withOrderColumns(orderColumns)
def of[D](ds: Dataset[D]): DataFrame = default().of(ds)
}
================================================
FILE: src/main/scala/uk/co/gresearch/spark/SparkVersion.scala
================================================
/*
* Copyright 2023 G-Research
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package uk.co.gresearch.spark
import org.apache.spark.SPARK_VERSION_SHORT
/**
* Provides versions from the runtime environment.
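*
* For example, on Spark 3.5.1, `SparkVersion` is `(3, 5, 1)` and `SparkCompatVersionString` is `"3.5"`.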
*/
trait SparkVersion {
private def SparkVersionSeq: Seq[Int] = SPARK_VERSION_SHORT.split('.').toSeq.map(_.toInt)
def SparkMajorVersion: Int = SparkVersionSeq.head
def SparkMinorVersion: Int = SparkVersionSeq(1)
def SparkPatchVersion: Int = SparkVersionSeq(2)
def SparkVersion: (Int, Int, Int) = (SparkMajorVersion, SparkMinorVersion, SparkPatchVersion)
def SparkCompatVersion: (Int, Int) = (SparkMajorVersion, SparkMinorVersion)
def SparkCompatVersionString: String = SparkVersionSeq.slice(0, 2).mkString(".")
}
================================================
FILE: src/main/scala/uk/co/gresearch/spark/UnpersistHandle.scala
================================================
/*
* Copyright 2022 G-Research
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package uk.co.gresearch.spark
import org.apache.spark.sql.DataFrame
/**
* Handle to call `DataFrame.unpersist` on a `DataFrame` that is not known to the caller. [[RowNumbers.of]]
* constructs a `DataFrame` that is based on an intermediate cached `DataFrame`, for which `unpersist` must be
* called. A provided [[UnpersistHandle]] allows user code to do exactly that.
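*
* A minimal usage sketch (assuming a Dataset `ds`):
* {{{
* val unpersist = UnpersistHandle()
* RowNumbers.withUnpersistHandle(unpersist).of(ds).show()
* unpersist()
* }}}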
*/
class UnpersistHandle {
var df: Option[DataFrame] = None
private[spark] def setDataFrame(dataframe: DataFrame): DataFrame = {
if (df.isDefined) throw new IllegalStateException("DataFrame has been set already, it cannot be reused.")
this.df = Some(dataframe)
dataframe
}
def apply(): Unit = {
this.df.getOrElse(throw new IllegalStateException("DataFrame has to be set first")).unpersist()
}
def apply(blocking: Boolean): Unit = {
this.df.getOrElse(throw new IllegalStateException("DataFrame has to be set first")).unpersist(blocking)
}
}
case class SilentUnpersistHandle() extends UnpersistHandle {
override def apply(): Unit = {
this.df.foreach(_.unpersist())
}
override def apply(blocking: Boolean): Unit = {
this.df.foreach(_.unpersist(blocking))
}
}
case class NoopUnpersistHandle() extends UnpersistHandle {
override def setDataFrame(dataframe: DataFrame): DataFrame = dataframe
override def apply(): Unit = {}
override def apply(blocking: Boolean): Unit = {}
}
object UnpersistHandle {
val Noop: NoopUnpersistHandle = NoopUnpersistHandle()
def apply(): UnpersistHandle = new UnpersistHandle()
def withUnpersist[T](blocking: Boolean = false)(func: UnpersistHandle => T): T = {
val handle = SilentUnpersistHandle()
try {
func(handle)
} finally {
handle(blocking)
}
}
}
================================================
FILE: src/main/scala/uk/co/gresearch/spark/diff/App.scala
================================================
/*
* Copyright 2023 G-Research
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package uk.co.gresearch.spark.diff
import org.apache.spark.sql.functions.col
import org.apache.spark.sql.{DataFrame, SaveMode, SparkSession}
import scopt.OptionParser
import uk.co.gresearch._
object App {
// define available options
case class Options(
master: Option[String] = None,
appName: Option[String] = None,
hive: Boolean = false,
leftPath: Option[String] = None,
rightPath: Option[String] = None,
outputPath: Option[String] = None,
leftFormat: Option[String] = None,
rightFormat: Option[String] = None,
outputFormat: Option[String] = None,
leftSchema: Option[String] = None,
rightSchema: Option[String] = None,
leftOptions: Map[String, String] = Map.empty,
rightOptions: Map[String, String] = Map.empty,
outputOptions: Map[String, String] = Map.empty,
ids: Seq[String] = Seq.empty,
ignore: Seq[String] = Seq.empty,
saveMode: SaveMode = SaveMode.ErrorIfExists,
filter: Set[String] = Set.empty,
statistics: Boolean = false,
diffOptions: DiffOptions = DiffOptions.default
)
// read options from args
val programName = s"spark-extension_${spark.BuildScalaCompatVersionString}-${spark.VersionString}.jar"
val scop = s"com.github.scopt:scopt_${spark.BuildScalaCompatVersionString}:4.1.0"
val sparkSubmit = s"spark-submit --packages $scop $programName"
val parser: OptionParser[Options] = new scopt.OptionParser[Options](programName) {
head(s"Spark Diff app (${spark.VersionString})")
head()
arg[String]("left")
.required()
.valueName("")
.action((x, c) => c.copy(leftPath = Some(x)))
.text("file path (requires format option) or table name to read left dataframe")
arg[String]("right")
.required()
.valueName("")
.action((x, c) => c.copy(rightPath = Some(x)))
.text("file path (requires format option) or table name to read right dataframe")
arg[String]("diff")
.required()
.valueName("")
.action((x, c) => c.copy(outputPath = Some(x)))
.text("file path (requires format option) or table name to write diff dataframe")
note("")
note("Examples:")
note("")
note(" - Diff CSV files 'left.csv' and 'right.csv' and write result into CSV file 'diff.csv':")
note(s" $sparkSubmit --format csv left.csv right.csv diff.csv")
note("")
note(" - Diff CSV file 'left.csv' and Parquet file 'right.parquet' with id column 'id',")
note(" and write result into Hive table 'diff':")
note(s" $sparkSubmit --left-format csv --right-format parquet --hive --id id left.csv right.parquet diff")
note("")
note("Spark session")
opt[String]("master")
.valueName("")
.action((x, c) => c.copy(master = Some(x)))
.text("Spark master (local, yarn, ...), not needed with spark-submit")
opt[String]("app-name")
.valueName("")
.action((x, c) => c.copy(appName = Some(x)))
.text("Spark application name")
.withFallback(() => "Diff App")
opt[Unit]("hive")
.optional()
.action((_, c) => c.copy(hive = true))
.text(s"enable Hive support to read from and write to Hive tables")
note("")
note("Input and output")
opt[String]('f', "format")
.valueName("")
.action((x, c) =>
c.copy(
leftFormat = c.leftFormat.orElse(Some(x)),
rightFormat = c.rightFormat.orElse(Some(x)),
outputFormat = c.outputFormat.orElse(Some(x))
)
)
.text("input and output file format (csv, json, parquet, ...)")
opt[String]("left-format")
.valueName("")
.action((x, c) => c.copy(leftFormat = Some(x)))
.text("left input file format (csv, json, parquet, ...)")
opt[String]("right-format")
.valueName("")
.action((x, c) => c.copy(rightFormat = Some(x)))
.text("right input file format (csv, json, parquet, ...)")
opt[String]("output-format")
.valueName("")
.action((x, c) => c.copy(outputFormat = Some(x)))
.text("output file format (csv, json, parquet, ...)")
note("")
opt[String]('s', "schema")
.valueName("")
.action((x, c) =>
c.copy(
leftSchema = c.leftSchema.orElse(Some(x)),
rightSchema = c.rightSchema.orElse(Some(x))
)
)
.text("input schema")
opt[String]("left-schema")
.valueName("")
.action((x, c) => c.copy(leftSchema = Some(x)))
.text("left input schema")
opt[String]("right-schema")
.valueName("")
.action((x, c) => c.copy(rightSchema = Some(x)))
.text("right input schema")
note("")
opt[(String, String)]("left-option")
.unbounded()
.optional()
.keyValueName("key", "val")
.action((x, c) => c.copy(leftOptions = c.leftOptions + (x._1 -> x._2)))
.text("left input option")
opt[(String, String)]("right-option")
.unbounded()
.optional()
.keyValueName("key", "val")
.action((x, c) => c.copy(rightOptions = c.rightOptions + (x._1 -> x._2)))
.text("right input option")
opt[(String, String)]("output-option")
.unbounded()
.optional()
.keyValueName("key", "val")
.action((x, c) => c.copy(outputOptions = c.outputOptions + (x._1 -> x._2)))
.text("output option")
note("")
opt[String]("id")
.unbounded()
.valueName("")
.action((x, c) => c.copy(ids = c.ids :+ x))
.text(s"id column name")
opt[String]("ignore")
.unbounded()
.valueName("")
.action((x, c) => c.copy(ignore = c.ignore :+ x))
.text(s"ignore column name")
opt[String]("save-mode")
.optional()
.valueName("")
.action((x, c) => c.copy(saveMode = SaveMode.valueOf(x)))
.text(s"save mode for writing output (${SaveMode.values().mkString(", ")}, default ${Options().saveMode})")
opt[String]("filter")
.unbounded()
.optional()
.valueName("")
.action((x, c) => c.copy(filter = c.filter + x))
.text(
s"Filters for rows with these diff actions, with default diffing options use 'N', 'I', 'D', or 'C' (see 'Diffing options' section)"
)
opt[Unit]("statistics")
.optional()
.action((_, c) => c.copy(statistics = true))
.text(s"Only output statistics on how many rows exist per diff action (see 'Diffing options' section)")
note("")
note("Diffing options")
opt[String]("diff-column")
.optional()
.valueName("")
.action((x, c) => c.copy(diffOptions = c.diffOptions.copy(diffColumn = x)))
.text(s"column name for diff column (default '${DiffOptions.default.diffColumn}')")
opt[String]("left-prefix")
.optional()
.valueName("")
.action((x, c) => c.copy(diffOptions = c.diffOptions.copy(leftColumnPrefix = x)))
.text(s"prefix for left column names (default '${DiffOptions.default.leftColumnPrefix}')")
opt[String]("right-prefix")
.optional()
.valueName("")
.action((x, c) => c.copy(diffOptions = c.diffOptions.copy(rightColumnPrefix = x)))
.text(s"prefix for right column names (default '${DiffOptions.default.rightColumnPrefix}')")
opt[String]("insert-value")
.optional()
.valueName("")
.action((x, c) => c.copy(diffOptions = c.diffOptions.copy(insertDiffValue = x)))
.text(s"value for insertion (default '${DiffOptions.default.insertDiffValue}')")
opt[String]("change-value")
.optional()
.valueName("")
.action((x, c) => c.copy(diffOptions = c.diffOptions.copy(changeDiffValue = x)))
.text(s"value for change (default '${DiffOptions.default.changeDiffValue}')")
opt[String]("delete-value")
.optional()
.valueName("")
.action((x, c) => c.copy(diffOptions = c.diffOptions.copy(deleteDiffValue = x)))
.text(s"value for deletion (default '${DiffOptions.default.deleteDiffValue}')")
opt[String]("no-change-value")
.optional()
.valueName("")
.action((x, c) => c.copy(diffOptions = c.diffOptions.copy(nochangeDiffValue = x)))
.text(s"value for no change (default '${DiffOptions.default.nochangeDiffValue}')")
opt[String]("change-column")
.optional()
.valueName("")
.action((x, c) => c.copy(diffOptions = c.diffOptions.copy(changeColumn = Some(x))))
.text(s"column name for change column (default is no such column)")
opt[String]("diff-mode")
.optional()
.valueName("")
.action((x, c) => c.copy(diffOptions = c.diffOptions.copy(diffMode = DiffMode.withName(x))))
.text(s"diff mode (${DiffMode.values.mkString(", ")}, default ${Options().diffOptions.diffMode})")
opt[Unit]("sparse")
.optional()
.action((_, c) => c.copy(diffOptions = c.diffOptions.copy(sparseMode = true)))
.text(s"enable sparse diff")
note("")
note("General")
help("help").text("prints this usage text")
}
def read(
spark: SparkSession,
format: Option[String],
path: String,
schema: Option[String],
options: Map[String, String]
): DataFrame =
spark.read
.when(format.isDefined)
.call(_.format(format.get))
.options(options)
.when(schema.isDefined)
.call(_.schema(schema.get))
.when(format.isDefined)
.either(_.load(path))
.or(_.table(path))
def write(
df: DataFrame,
format: Option[String],
path: String,
options: Map[String, String],
saveMode: SaveMode,
filter: Set[String],
saveStats: Boolean,
diffOptions: DiffOptions
): Unit =
df.when(filter.nonEmpty)
.call(_.where(col(diffOptions.diffColumn).isInCollection(filter)))
.when(saveStats)
.call(_.groupBy(diffOptions.diffColumn).count.orderBy(diffOptions.diffColumn))
.write
.when(format.isDefined)
.call(_.format(format.get))
.options(options)
.mode(saveMode)
.when(format.isDefined)
.either(_.save(path))
.or(_.saveAsTable(path))
def main(args: Array[String]): Unit = {
// parse options
val options = parser.parse(args, Options()) match {
case Some(options) => options
case None => sys.exit(1)
}
val unknownFilters = options.filter.filter(filter => !options.diffOptions.diffValues.contains(filter))
if (unknownFilters.nonEmpty) {
throw new RuntimeException(
s"Filter ${unknownFilters.mkString("'", "', '", "'")} not allowed, " +
s"these are the configured diff values: ${options.diffOptions.diffValues.mkString("'", "', '", "'")}"
)
}
// create spark session
val spark = SparkSession
.builder()
.appName(options.appName.get)
.when(options.hive)
.call(_.enableHiveSupport())
.when(options.master.isDefined)
.call(_.master(options.master.get))
.getOrCreate()
// read and write
val left = read(spark, options.leftFormat, options.leftPath.get, options.leftSchema, options.leftOptions)
val right = read(spark, options.rightFormat, options.rightPath.get, options.rightSchema, options.rightOptions)
val diff = left.diff(right, options.diffOptions, options.ids, options.ignore)
write(
diff,
options.outputFormat,
options.outputPath.get,
options.outputOptions,
options.saveMode,
options.filter,
options.statistics,
options.diffOptions
)
}
}
================================================
FILE: src/main/scala/uk/co/gresearch/spark/diff/Diff.scala
================================================
/*
* Copyright 2020 G-Research
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package uk.co.gresearch.spark.diff
import org.apache.spark.sql._
import org.apache.spark.sql.functions._
import org.apache.spark.sql.types.{ArrayType, StringType}
import uk.co.gresearch.spark.diff.comparator.DiffComparator
import uk.co.gresearch.spark.{backticks, distinctPrefixFor}
import scala.collection.JavaConverters
/**
* Differ class to diff two Datasets. See Differ.of(…) for details.
* @param options
* options for the diffing process
*/
class Differ(options: DiffOptions) {
private[diff] def checkSchema[T, U](
left: Dataset[T],
right: Dataset[U],
idColumns: Seq[String],
ignoreColumns: Seq[String]
): Unit = {
require(
left.columns.length == left.columns.toSet.size &&
right.columns.length == right.columns.toSet.size,
"The datasets have duplicate columns.\n" +
s"Left column names: ${left.columns.mkString(", ")}\n" +
s"Right column names: ${right.columns.mkString(", ")}"
)
val leftNonIgnored = left.columns.diffCaseSensitivity(ignoreColumns)
val rightNonIgnored = right.columns.diffCaseSensitivity(ignoreColumns)
val exceptIgnoredColumnsMsg = if (ignoreColumns.nonEmpty) " except ignored columns" else ""
require(
leftNonIgnored.length == rightNonIgnored.length,
"The number of columns doesn't match.\n" +
s"Left column names$exceptIgnoredColumnsMsg (${leftNonIgnored.length}): ${leftNonIgnored.mkString(", ")}\n" +
s"Right column names$exceptIgnoredColumnsMsg (${rightNonIgnored.length}): ${rightNonIgnored.mkString(", ")}"
)
require(leftNonIgnored.length > 0, s"The schema$exceptIgnoredColumnsMsg must not be empty")
// column types must match but we ignore the nullability of columns
val leftFields = left.schema.fields
.filter(f => !ignoreColumns.containsCaseSensitivity(f.name))
.map(f => handleConfiguredCaseSensitivity(f.name) -> f.dataType)
val rightFields = right.schema.fields
.filter(f => !ignoreColumns.containsCaseSensitivity(f.name))
.map(f => handleConfiguredCaseSensitivity(f.name) -> f.dataType)
val leftExtraSchema = leftFields.diff(rightFields)
val rightExtraSchema = rightFields.diff(leftFields)
require(
leftExtraSchema.isEmpty && rightExtraSchema.isEmpty,
"The datasets do not have the same schema.\n" +
s"Left extra columns: ${leftExtraSchema.map(t => s"${t._1} (${t._2})").mkString(", ")}\n" +
s"Right extra columns: ${rightExtraSchema.map(t => s"${t._1} (${t._2})").mkString(", ")}"
)
val columns = leftNonIgnored
val pkColumns = if (idColumns.isEmpty) columns.toList else idColumns
val nonPkColumns = columns.diffCaseSensitivity(pkColumns)
val missingIdColumns = pkColumns.diffCaseSensitivity(columns)
require(
missingIdColumns.isEmpty,
s"Some id columns do not exist: ${missingIdColumns.mkString(", ")} missing among ${columns.mkString(", ")}"
)
val missingIgnoreColumns = ignoreColumns.diffCaseSensitivity(left.columns).diffCaseSensitivity(right.columns)
require(
missingIgnoreColumns.isEmpty,
s"Some ignore columns do not exist: ${missingIgnoreColumns.mkString(", ")} " +
s"missing among ${(leftNonIgnored ++ rightNonIgnored).distinct.sorted.mkString(", ")}"
)
require(
!pkColumns.containsCaseSensitivity(options.diffColumn),
s"The id columns must not contain the diff column name '${options.diffColumn}': ${pkColumns.mkString(", ")}"
)
require(
options.changeColumn.forall(!pkColumns.containsCaseSensitivity(_)),
s"The id columns must not contain the change column name '${options.changeColumn.get}': ${pkColumns.mkString(", ")}"
)
val diffValueColumns = getDiffValueColumns(pkColumns, nonPkColumns, left, right, ignoreColumns).map(_._1)
if (Seq(DiffMode.LeftSide, DiffMode.RightSide).contains(options.diffMode)) {
require(
!diffValueColumns.containsCaseSensitivity(options.diffColumn),
s"The ${if (options.diffMode == DiffMode.LeftSide) "left" else "right"} " +
s"non-id columns must not contain the diff column name '${options.diffColumn}': " +
s"${(if (options.diffMode == DiffMode.LeftSide) left else right).columns.diffCaseSensitivity(idColumns).mkString(", ")}"
)
require(
options.changeColumn.forall(!diffValueColumns.containsCaseSensitivity(_)),
s"The ${if (options.diffMode == DiffMode.LeftSide) "left" else "right"} " +
s"non-id columns must not contain the change column name '${options.changeColumn.get}': " +
s"${(if (options.diffMode == DiffMode.LeftSide) left else right).columns.diffCaseSensitivity(idColumns).mkString(", ")}"
)
} else {
require(
!diffValueColumns.containsCaseSensitivity(options.diffColumn),
s"The column prefixes '${options.leftColumnPrefix}' and '${options.rightColumnPrefix}', " +
s"together with these non-id columns " +
s"must not produce the diff column name '${options.diffColumn}': " +
s"${nonPkColumns.mkString(", ")}"
)
require(
options.changeColumn.forall(!diffValueColumns.containsCaseSensitivity(_)),
s"The column prefixes '${options.leftColumnPrefix}' and '${options.rightColumnPrefix}', " +
s"together with these non-id columns " +
s"must not produce the change column name '${options.changeColumn.orNull}': " +
s"${nonPkColumns.mkString(", ")}"
)
require(
diffValueColumns.forall(!pkColumns.containsCaseSensitivity(_)),
s"The column prefixes '${options.leftColumnPrefix}' and '${options.rightColumnPrefix}', " +
s"together with these non-id columns " +
s"must not produce any id column name '${pkColumns.mkString("', '")}': " +
s"${nonPkColumns.mkString(", ")}"
)
}
}
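// Builds the optional change column: null for inserted and deleted rows, otherwise an array of the
// names of all value columns whose left and right values are not equivalent w.r.t. their comparator.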
private def getChangeColumn(
existsColumnName: String,
valueColumnsWithComparator: Seq[(String, DiffComparator)],
left: Dataset[_],
right: Dataset[_]
): Option[Column] = {
options.changeColumn
.map(changeColumn =>
when(left(existsColumnName).isNull || right(existsColumnName).isNull, lit(null))
.otherwise(
Some(valueColumnsWithComparator)
.filter(_.nonEmpty)
.map(columns =>
concat(
columns
.map { case (c, cmp) =>
when(cmp.equiv(left(backticks(c)), right(backticks(c))), array()).otherwise(array(lit(c)))
}: _*
)
)
.getOrElse(
array().cast(ArrayType(StringType, containsNull = false))
)
)
.as(changeColumn)
)
}
private[diff] def getDiffIdColumns[T, U](
pkColumns: Seq[String],
left: Dataset[T],
right: Dataset[U],
): Seq[(String, Column)] = {
pkColumns.map(c => c -> coalesce(left(backticks(c)), right(backticks(c))).as(c))
}
private[diff] def getDiffValueColumns[T, U](
pkColumns: Seq[String],
valueColumns: Seq[String],
left: Dataset[T],
right: Dataset[U],
ignoreColumns: Seq[String]
): Seq[(String, Column)] = {
val leftValueColumns = left.columns.filterIsInCaseSensitivity(valueColumns)
val rightValueColumns = right.columns.filterIsInCaseSensitivity(valueColumns)
val leftNonPkColumns = left.columns.diffCaseSensitivity(pkColumns)
val rightNonPkColumns = right.columns.diffCaseSensitivity(pkColumns)
val leftIgnoredColumns = left.columns.filterIsInCaseSensitivity(ignoreColumns)
val rightIgnoredColumns = right.columns.filterIsInCaseSensitivity(ignoreColumns)
val (leftValues, rightValues) = if (options.sparseMode) {
(
leftNonPkColumns
.map(c =>
(
handleConfiguredCaseSensitivity(c),
c -> when(not(left(backticks(c)) <=> right(backticks(c))), left(backticks(c)))
)
)
.toMap,
rightNonPkColumns
.map(c =>
(
handleConfiguredCaseSensitivity(c),
c -> when(not(left(backticks(c)) <=> right(backticks(c))), right(backticks(c)))
)
)
.toMap
)
} else {
(
leftNonPkColumns.map(c => (handleConfiguredCaseSensitivity(c), c -> left(backticks(c)))).toMap,
rightNonPkColumns.map(c => (handleConfiguredCaseSensitivity(c), c -> right(backticks(c)))).toMap,
)
}
def alias(prefix: Option[String], values: Map[String, (String, Column)])(name: String): (String, Column) = {
values(handleConfiguredCaseSensitivity(name)) match {
case (name, column) =>
val alias = prefix.map(p => s"${p}_$name").getOrElse(name)
alias -> column.as(alias)
}
}
def aliasLeft(name: String): (String, Column) = alias(Some(options.leftColumnPrefix), leftValues)(name)
def aliasRight(name: String): (String, Column) = alias(Some(options.rightColumnPrefix), rightValues)(name)
val prefixedLeftIgnoredColumns = leftIgnoredColumns.map(c => aliasLeft(c))
val prefixedRightIgnoredColumns = rightIgnoredColumns.map(c => aliasRight(c))
options.diffMode match {
case DiffMode.ColumnByColumn =>
valueColumns.flatMap(c =>
Seq(
aliasLeft(c),
aliasRight(c)
)
) ++ ignoreColumns.flatMap(c =>
(if (leftIgnoredColumns.containsCaseSensitivity(c)) Seq(aliasLeft(c)) else Seq.empty) ++
(if (rightIgnoredColumns.containsCaseSensitivity(c)) Seq(aliasRight(c)) else Seq.empty)
)
case DiffMode.SideBySide =>
leftValueColumns.toSeq.map(c => aliasLeft(c)) ++ prefixedLeftIgnoredColumns ++
rightValueColumns.toSeq.map(c => aliasRight(c)) ++ prefixedRightIgnoredColumns
case DiffMode.LeftSide | DiffMode.RightSide =>
// in left-side / right-side mode, we do not prefix columns
(
if (options.diffMode == DiffMode.LeftSide) valueColumns.map(alias(None, leftValues))
else valueColumns.map(alias(None, rightValues))
) ++ (
if (options.diffMode == DiffMode.LeftSide) leftIgnoredColumns.map(alias(None, leftValues))
else rightIgnoredColumns.map(alias(None, rightValues))
)
}
}
private[diff] def getDiffColumns[T, U](
pkColumns: Seq[String],
valueColumns: Seq[String],
left: Dataset[T],
right: Dataset[U],
ignoreColumns: Seq[String]
): Seq[(String, Column)] = {
getDiffIdColumns(pkColumns, left, right) ++ getDiffValueColumns(pkColumns, valueColumns, left, right, ignoreColumns)
}
private def doDiff[T, U](
left: Dataset[T],
right: Dataset[U],
idColumns: Seq[String],
ignoreColumns: Seq[String] = Seq.empty
): DataFrame = {
checkSchema(left, right, idColumns, ignoreColumns)
val columns = left.columns.diffCaseSensitivity(ignoreColumns).toList
val pkColumns = if (idColumns.isEmpty) columns else idColumns
val valueColumns = columns.diffCaseSensitivity(pkColumns)
val valueStructFields = left.schema.fields.map(f => f.name -> f).toMap
val valueColumnsWithComparator = valueColumns.map(c => c -> options.comparatorFor(valueStructFields(c)))
val existsColumnName = distinctPrefixFor(left.columns) + "exists"
val leftWithExists = left.withColumn(existsColumnName, lit(1))
val rightWithExists = right.withColumn(existsColumnName, lit(1))
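// after the full outer join below, the synthetic exists columns reveal which side a row
// originates from: the column is null on the side where the row is missing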
val joinCondition =
pkColumns.map(c => leftWithExists(backticks(c)) <=> rightWithExists(backticks(c))).reduce(_ && _)
val unChanged = valueColumnsWithComparator
.map { case (c, cmp) =>
cmp.equiv(leftWithExists(backticks(c)), rightWithExists(backticks(c)))
}
.reduceOption(_ && _)
val changeCondition = not(unChanged.getOrElse(lit(true)))
val diffActionColumn =
when(leftWithExists(existsColumnName).isNull, lit(options.insertDiffValue))
.when(rightWithExists(existsColumnName).isNull, lit(options.deleteDiffValue))
.when(changeCondition, lit(options.changeDiffValue))
.otherwise(lit(options.nochangeDiffValue))
.as(options.diffColumn)
val diffColumns = getDiffColumns(pkColumns, valueColumns, left, right, ignoreColumns).map(_._2)
val changeColumn = getChangeColumn(existsColumnName, valueColumnsWithComparator, leftWithExists, rightWithExists)
// turn this optional column into a sequence of zero or one columns so we can easily concatenate it below with diffActionColumn and diffColumns
.map(Seq(_))
.getOrElse(Seq.empty[Column])
leftWithExists
.join(rightWithExists, joinCondition, "fullouter")
.select((diffActionColumn +: changeColumn) ++ diffColumns: _*)
}
/**
* Returns a new DataFrame that contains the differences between two Datasets of the same type `T`. Both Datasets must
* contain the same set of column names and data types. The order of columns in the two Datasets is not relevant as
* columns are compared based on the name, not the position.
*
* Optional `id` columns are used to uniquely identify rows to compare. If values in any non-id column differ
* between the two Datasets, that row is marked as `"C"`hange, otherwise as `"N"`o-change. Rows of the right
* Dataset that do not exist in the left Dataset (w.r.t. the values in the id columns) are marked as `"I"`nsert,
* and rows of the left Dataset that do not exist in the right Dataset are marked as `"D"`elete.
*
* If no id columns are given, all columns are considered id columns. Then, no `"C"`hange rows will appear, as all
* changes will exist as respective `"D"`elete and `"I"`nsert rows.
*
* The returned DataFrame has the `diff` column as the first column. This holds the `"N"`, `"C"`, `"I"` or `"D"`
* strings. The id columns follow, then the non-id columns (all remaining columns).
*
* {{{
* val df1 = Seq((1, "one"), (2, "two"), (3, "three")).toDF("id", "value")
* val df2 = Seq((1, "one"), (2, "Two"), (4, "four")).toDF("id", "value")
*
* differ.diff(df1, df2).show()
*
* // output:
* // +----+---+-----+
* // |diff| id|value|
* // +----+---+-----+
* // | N| 1| one|
* // | D| 2| two|
* // | I| 2| Two|
* // | D| 3|three|
* // | I| 4| four|
* // +----+---+-----+
*
* differ.diff(df1, df2, "id").show()
*
* // output:
* // +----+---+----------+-----------+
* // |diff| id|left_value|right_value|
* // +----+---+----------+-----------+
* // | N| 1| one| one|
* // | C| 2| two| Two|
* // | D| 3| three| null|
* // | I| 4| null| four|
* // +----+---+----------+-----------+
*
* }}}
*
* The id columns are in order as given to the method. If no id columns are given then all columns of this Dataset are
* id columns and appear in the same order. The remaining non-id columns are in the order of this Dataset.
*/
@scala.annotation.varargs
def diff[T](left: Dataset[T], right: Dataset[T], idColumns: String*): DataFrame =
doDiff(left, right, idColumns)
/**
* Returns a new DataFrame that contains the differences between two Datasets of similar types `T` and `U`. Both
* Datasets must contain the same set of column names and data types, except for the columns in `ignoreColumns`. The
* order of columns in the two Datasets is not relevant as columns are compared based on the name, not the
* position.
*
* Optional id columns are used to uniquely identify rows to compare. If values in any non-id column differ
* between the two Datasets, that row is marked as `"C"`hange, otherwise as `"N"`o-change. Rows of the right
* Dataset that do not exist in the left Dataset (w.r.t. the values in the id columns) are marked as `"I"`nsert,
* and rows of the left Dataset that do not exist in the right Dataset are marked as `"D"`elete.
*
* If no id columns are given, all columns are considered id columns. Then, no `"C"`hange rows will appear, as all
* changes will exist as respective `"D"`elete and `"I"`nsert rows.
*
* Values in optional ignore columns are not compared but included in the output DataFrame.
*
* The returned DataFrame has the `diff` column as the first column. This holds the `"N"`, `"C"`, `"I"` or `"D"`
* strings. The id columns follow, then the non-id columns (all remaining columns).
*
* {{{
* val df1 = Seq((1, "one"), (2, "two"), (3, "three")).toDF("id", "value")
* val df2 = Seq((1, "one"), (2, "Two"), (4, "four")).toDF("id", "value")
*
* differ.diff(df1, df2).show()
*
* // output:
* // +----+---+-----+
* // |diff| id|value|
* // +----+---+-----+
* // | N| 1| one|
* // | D| 2| two|
* // | I| 2| Two|
* // | D| 3|three|
* // | I| 4| four|
* // +----+---+-----+
*
* differ.diff(df1, df2, Seq("id")).show()
*
* // output:
* // +----+---+----------+-----------+
* // |diff| id|left_value|right_value|
* // +----+---+----------+-----------+
* // | N| 1| one| one|
* // | C| 2| two| Two|
* // | D| 3| three| null|
* // | I| 4| null| four|
* // +----+---+----------+-----------+
*
* }}}
*
* The id columns are in order as given to the method. If no id columns are given then all columns of this Dataset are
* id columns and appear in the same order. The remaining non-id columns are in the order of this Dataset.
*/
def diff[T, U](left: Dataset[T], right: Dataset[U], idColumns: Seq[String], ignoreColumns: Seq[String]): DataFrame =
doDiff(left, right, idColumns, ignoreColumns)
/**
* Returns a new DataFrame that contains the differences between two Datasets of similar types `T` and `U`. Both
* Datasets must contain the same set of column names and data types, except for the columns in `ignoreColumns`. The
* order of columns in the two Datasets is not relevant as columns are compared based on the name, not the
* position.
*
* Optional id columns are used to uniquely identify rows to compare. If values in any non-id column differ
* between the two Datasets, that row is marked as `"C"`hange, otherwise as `"N"`o-change. Rows of the right
* Dataset that do not exist in the left Dataset (w.r.t. the values in the id columns) are marked as `"I"`nsert,
* and rows of the left Dataset that do not exist in the right Dataset are marked as `"D"`elete.
*
* If no id columns are given, all columns are considered id columns. Then, no `"C"`hange rows will appear, as all
* changes will exist as respective `"D"`elete and `"I"`nsert rows.
*
* Values in optional ignore columns are not compared but included in the output DataFrame.
*
* The returned DataFrame has the `diff` column as the first column. This holds the `"N"`, `"C"`, `"I"` or `"D"`
* strings. The id columns follow, then the non-id columns (all remaining columns).
*
* {{{
* val df1 = Seq((1, "one"), (2, "two"), (3, "three")).toDF("id", "value")
* val df2 = Seq((1, "one"), (2, "Two"), (4, "four")).toDF("id", "value")
*
* differ.diff(df1, df2).show()
*
* // output:
* // +----+---+-----+
* // |diff| id|value|
* // +----+---+-----+
* // | N| 1| one|
* // | D| 2| two|
* // | I| 2| Two|
* // | D| 3|three|
* // | I| 4| four|
* // +----+---+-----+
*
* differ.diff(df1, df2, Seq("id")).show()
*
* // output:
* // +----+---+----------+-----------+
* // |diff| id|left_value|right_value|
* // +----+---+----------+-----------+
* // | N| 1| one| one|
* // | C| 2| two| Two|
* // | D| 3| three| null|
* // | I| 4| null| four|
* // +----+---+----------+-----------+
*
* }}}
*
* The id columns are in order as given to the method. If no id columns are given then all columns of this Dataset are
* id columns and appear in the same order. The remaining non-id columns are in the order of this Dataset.
*/
def diff[T, U](
left: Dataset[T],
right: Dataset[U],
idColumns: java.util.List[String],
ignoreColumns: java.util.List[String]
): DataFrame = {
diff(
left,
right,
JavaConverters.iterableAsScalaIterable(idColumns).toSeq,
JavaConverters.iterableAsScalaIterable(ignoreColumns).toSeq
)
}
/**
* Returns a new Dataset that contains the differences between two Datasets of the same type `T`.
*
* See `diff(Dataset[T], Dataset[T], String*)`.
*
* This requires an additional implicit `Encoder[V]` for the return type `Dataset[V]`.
*/
// no @scala.annotation.varargs here as implicit arguments are explicit in Java
// this signature is redundant to the other diffAs method in Java
def diffAs[T, V](left: Dataset[T], right: Dataset[T], idColumns: String*)(implicit
diffEncoder: Encoder[V]
): Dataset[V] = {
diffAs(left, right, diffEncoder, idColumns: _*)
}
/**
* Returns a new Dataset that contains the differences between two Datasets of similar types `T` and `U`.
*
* See `diff(Dataset[T], Dataset[U], Seq[String], Seq[String])`.
*
* This requires an additional implicit `Encoder[V]` for the return type `Dataset[V]`.
*/
def diffAs[T, U, V](left: Dataset[T], right: Dataset[U], idColumns: Seq[String], ignoreColumns: Seq[String])(implicit
diffEncoder: Encoder[V]
): Dataset[V] = {
diffAs(left, right, diffEncoder, idColumns, ignoreColumns)
}
/**
* Returns a new Dataset that contains the differences between two Datasets of the same type `T`.
*
* See `diff(Dataset[T], Dataset[T], String*)`.
*
* This requires an additional explicit `Encoder[V]` for the return type `Dataset[V]`.
*/
@scala.annotation.varargs
def diffAs[T, V](left: Dataset[T], right: Dataset[T], diffEncoder: Encoder[V], idColumns: String*): Dataset[V] = {
diffAs(left, right, diffEncoder, idColumns, Seq.empty)
}
/**
* Returns a new Dataset that contains the differences between two Datasets of similar types `T` and `U`.
*
* See `diff(Dataset[T], Dataset[U], Seq[String], Seq[String])`.
*
* This requires an additional explicit `Encoder[V]` for the return type `Dataset[V]`.
*/
def diffAs[T, U, V](
left: Dataset[T],
right: Dataset[U],
diffEncoder: Encoder[V],
idColumns: Seq[String],
ignoreColumns: Seq[String]
): Dataset[V] = {
val nonIdColumns =
if (idColumns.isEmpty) Seq.empty
else left.columns.diffCaseSensitivity(idColumns).diffCaseSensitivity(ignoreColumns).toSeq
val encColumns = diffEncoder.schema.fields.map(_.name)
val diffColumns =
Seq(options.diffColumn) ++ getDiffColumns(idColumns, nonIdColumns, left, right, ignoreColumns).map(_._1)
val extraColumns = encColumns.diffCaseSensitivity(diffColumns)
require(
extraColumns.isEmpty,
s"Diff encoder's columns must be part of the diff result schema, " +
s"these columns are unexpected: ${extraColumns.mkString(", ")}"
)
diff(left, right, idColumns, ignoreColumns).as[V](diffEncoder)
}
/**
* Returns a new Dataset that contains the differences between two Datasets of similar types `T` and `U`.
*
* See `diff(Dataset[T], Dataset[U], Seq[String], Seq[String])`.
*
* This requires an additional explicit `Encoder[V]` for the return type `Dataset[V]`.
*/
def diffAs[T, U, V](
left: Dataset[T],
right: Dataset[U],
diffEncoder: Encoder[V],
idColumns: java.util.List[String],
ignoreColumns: java.util.List[String]
): Dataset[V] = {
diffAs(
left,
right,
diffEncoder,
JavaConverters.iterableAsScalaIterable(idColumns).toSeq,
JavaConverters.iterableAsScalaIterable(ignoreColumns).toSeq
)
}
/**
* Returns a new Dataset that contains the differences between two Datasets of the same type `T` as tuples of type
* `(String, T, T)`.
*
* See `diff(Dataset[T], Dataset[T], String*)`.
*/
@scala.annotation.varargs
def diffWith[T](left: Dataset[T], right: Dataset[T], idColumns: String*): Dataset[(String, T, T)] = {
val df = diff(left, right, idColumns: _*)
diffWith(df, idColumns: _*)(left.encoder, right.encoder)
}
/**
* Returns a new Dataset that contains the differences between two Datasets of similar types `T` and `U` as tuples of
* type `(String, T, U)`.
*
* See `diff(Dataset[T], Dataset[U], Seq[String], Seq[String])`.
*/
def diffWith[T, U](
left: Dataset[T],
right: Dataset[U],
idColumns: Seq[String],
ignoreColumns: Seq[String]
): Dataset[(String, T, U)] = {
val df = diff(left, right, idColumns, ignoreColumns)
diffWith(df, idColumns: _*)(left.encoder, right.encoder)
}
/**
* Returns a new Dataset that contains the differences between two Datasets of similar types `T` and `U` as tuples of
* type `(String, T, U)`.
*
* See `diff(Dataset[T], Dataset[U], Seq[String], Seq[String])`.
*/
def diffWith[T, U](
left: Dataset[T],
right: Dataset[U],
idColumns: java.util.List[String],
ignoreColumns: java.util.List[String]
): Dataset[(String, T, U)] = {
diffWith(
left,
right,
JavaConverters.iterableAsScalaIterable(idColumns).toSeq,
JavaConverters.iterableAsScalaIterable(ignoreColumns).toSeq
)
}
private def columnsOfSide(df: DataFrame, idColumns: Seq[String], sidePrefix: String): Seq[Column] = {
val prefix = sidePrefix + "_"
df.columns
.filter(c => idColumns.contains(c) || c.startsWith(sidePrefix))
.map(c => if (idColumns.contains(c)) col(c) else col(c).as(c.replace(prefix, "")))
}
private def diffWith[T: Encoder, U: Encoder](diff: DataFrame, idColumns: String*): Dataset[(String, T, U)] = {
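// reassemble the left and right rows as structs from their prefixed columns; the struct is
// null on the side where the row does not exist (inserts have no left row, deletes no right row)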
val leftColumns = columnsOfSide(diff, idColumns, options.leftColumnPrefix)
val rightColumns = columnsOfSide(diff, idColumns, options.rightColumnPrefix)
val diffColumn = col(options.diffColumn).as("_1")
val leftStruct = when(col(options.diffColumn) === options.insertDiffValue, lit(null))
.otherwise(struct(leftColumns: _*))
.as("_2")
val rightStruct = when(col(options.diffColumn) === options.deleteDiffValue, lit(null))
.otherwise(struct(rightColumns: _*))
.as("_3")
val encoder: Encoder[(String, T, U)] = Encoders.tuple(
Encoders.STRING,
implicitly[Encoder[T]],
implicitly[Encoder[U]]
)
diff.select(diffColumn, leftStruct, rightStruct).as(encoder)
}
}
/**
* Diffing singleton with default diffing options.
*/
object Diff {
val default = new Differ(DiffOptions.default)
/**
* Returns a new DataFrame that contains the differences between two Datasets of the same type `T`. Both Datasets must
* contain the same set of column names and data types. The order of columns in the two Datasets is not relevant as
* columns are compared based on the name, not the position.
*
* Optional id columns are used to uniquely identify rows to compare. If values in any non-id column differ
* between the two Datasets, that row is marked as `"C"`hange, otherwise as `"N"`o-change. Rows of the right
* Dataset that do not exist in the left Dataset (w.r.t. the values in the id columns) are marked as `"I"`nsert,
* and rows of the left Dataset that do not exist in the right Dataset are marked as `"D"`elete.
*
* If no id columns are given, all columns are considered id columns. Then, no `"C"`hange rows will appear, as all
* changes will exist as respective `"D"`elete and `"I"`nsert rows.
*
* The returned DataFrame has the `diff` column as the first column. This holds the `"N"`, `"C"`, `"I"` or `"D"`
* strings. The id columns follow, then the non-id columns (all remaining columns).
*
* {{{
* val df1 = Seq((1, "one"), (2, "two"), (3, "three")).toDF("id", "value")
* val df2 = Seq((1, "one"), (2, "Two"), (4, "four")).toDF("id", "value")
*
* Diff.of(df1, df2).show()
*
* // output:
* // +----+---+-----+
* // |diff| id|value|
* // +----+---+-----+
* // | N| 1| one|
* // | D| 2| two|
* // | I| 2| Two|
* // | D| 3|three|
* // | I| 4| four|
* // +----+---+-----+
*
* Diff.of(df1, df2, "id").show()
*
* // output:
* // +----+---+----------+-----------+
* // |diff| id|left_value|right_value|
* // +----+---+----------+-----------+
* // | N| 1| one| one|
* // | C| 2| two| Two|
* // | D| 3| three| null|
* // | I| 4| null| four|
* // +----+---+----------+-----------+
*
* }}}
*
* The id columns are in order as given to the method. If no id columns are given then all columns of this Dataset are
* id columns and appear in the same order. The remaining non-id columns are in the order of this Dataset.
*/
@scala.annotation.varargs
def of[T](left: Dataset[T], right: Dataset[T], idColumns: String*): DataFrame =
default.diff(left, right, idColumns: _*)
/**
* Returns a new DataFrame that contains the differences between two Datasets of similar types `T` and `U`. Both
* Datasets must contain the same set of column names and data types, except for the columns in `ignoreColumns`. The
* order of columns in the two Datasets is not relevant as columns are compared based on the name, not the
* position.
*
* Optional id columns are used to uniquely identify rows to compare. If values in any non-id column differ
* between the two Datasets, that row is marked as `"C"`hange, otherwise as `"N"`o-change. Rows of the right
* Dataset that do not exist in the left Dataset (w.r.t. the values in the id columns) are marked as `"I"`nsert,
* and rows of the left Dataset that do not exist in the right Dataset are marked as `"D"`elete.
*
* If no id columns are given, all columns are considered id columns. Then, no `"C"`hange rows will appear, as all
* changes will exist as respective `"D"`elete and `"I"`nsert rows.
*
* Values in optional ignore columns are not compared but included in the output DataFrame.
*
* The returned DataFrame has the `diff` column as the first column. This holds the `"N"`, `"C"`, `"I"` or `"D"`
* strings. The id columns follow, then the non-id columns (all remaining columns).
*
* {{{
* val df1 = Seq((1, "one"), (2, "two"), (3, "three")).toDF("id", "value")
* val df2 = Seq((1, "one"), (2, "Two"), (4, "four")).toDF("id", "value")
*
* Diff.of(df1, df2).show()
*
* // output:
* // +----+---+-----+
* // |diff| id|value|
* // +----+---+-----+
* // | N| 1| one|
* // | D| 2| two|
* // | I| 2| Two|
* // | D| 3|three|
* // | I| 4| four|
* // +----+---+-----+
*
* Diff.of(df1, df2, Seq("id")).show()
*
* // output:
* // +----+---+----------+-----------+
* // |diff| id|left_value|right_value|
* // +----+---+----------+-----------+
* // | N| 1| one| one|
* // | C| 2| two| Two|
* // | D| 3| three| null|
* // | I| 4| null| four|
* // +----+---+----------+-----------+
*
* }}}
*
* The id columns are in order as given to the method. If no id columns are given then all columns of this Dataset are
* id columns and appear in the same order. The remaining non-id columns are in the order of this Dataset.
*/
def of[T, U](
left: Dataset[T],
right: Dataset[U],
idColumns: Seq[String],
ignoreColumns: Seq[String] = Seq.empty
): DataFrame =
default.diff(left, right, idColumns, ignoreColumns)
/**
* Returns a new DataFrame that contains the differences between two Datasets of similar types `T` and `U`. Both
* Datasets must contain the same set of column names and data types, except for the columns in `ignoreColumns`. The
* order of columns in the two Datasets is not relevant as columns are compared based on the name, not the
* position.
*
* Optional id columns are used to uniquely identify rows to compare. If values in any non-id column differ
* between the two Datasets, that row is marked as `"C"`hange, otherwise as `"N"`o-change. Rows of the right
* Dataset that do not exist in the left Dataset (w.r.t. the values in the id columns) are marked as `"I"`nsert,
* and rows of the left Dataset that do not exist in the right Dataset are marked as `"D"`elete.
*
* If no id columns are given, all columns are considered id columns. Then, no `"C"`hange rows will appear, as all
* changes will exist as respective `"D"`elete and `"I"`nsert rows.
*
* Values in optional ignore columns are not compared but included in the output DataFrame.
*
* The returned DataFrame has the `diff` column as the first column. This holds the `"N"`, `"C"`, `"I"` or `"D"`
* strings. The id columns follow, then the non-id columns (all remaining columns).
*
* {{{
* val df1 = Seq((1, "one"), (2, "two"), (3, "three")).toDF("id", "value")
* val df2 = Seq((1, "one"), (2, "Two"), (4, "four")).toDF("id", "value")
*
* Diff.of(df1, df2).show()
*
* // output:
* // +----+---+-----+
* // |diff| id|value|
* // +----+---+-----+
* // | N| 1| one|
* // | D| 2| two|
* // | I| 2| Two|
* // | D| 3|three|
* // | I| 4| four|
* // +----+---+-----+
*
* Diff.of(df1, df2, Seq("id")).show()
*
* // output:
* // +----+---+----------+-----------+
* // |diff| id|left_value|right_value|
* // +----+---+----------+-----------+
* // | N| 1| one| one|
* // | C| 2| two| Two|
* // | D| 3| three| null|
* // | I| 4| null| four|
* // +----+---+----------+-----------+
*
* }}}
*
* The id columns are in order as given to the method. If no id columns are given then all columns of this Dataset are
* id columns and appear in the same order. The remaining non-id columns are in the order of this Dataset.
*/
def of[T, U](
left: Dataset[T],
right: Dataset[U],
idColumns: java.util.List[String],
ignoreColumns: java.util.List[String]
): DataFrame =
default.diff(left, right, idColumns, ignoreColumns)
/**
* Returns a new Dataset that contains the differences between two Datasets of the same type `T`.
*
* See `of(Dataset[T], Dataset[T], String*)`.
*
* This requires an additional implicit `Encoder[V]` for the return type `Dataset[V]`.
*/
// no @scala.annotation.varargs here as implicit arguments are explicit in Java
// this signature is redundant to the other ofAs method in Java
def ofAs[T, V](left: Dataset[T], right: Dataset[T], idColumns: String*)(implicit
diffEncoder: Encoder[V]
): Dataset[V] =
default.diffAs(left, right, idColumns: _*)
/**
* Returns a new Dataset that contains the differences between two Datasets of similar types `T` and `U`.
*
* See `of(Dataset[T], Dataset[U], Seq[String], Seq[String])`.
*
* This requires an additional implicit `Encoder[V]` for the return type `Dataset[V]`.
*/
def ofAs[T, U, V](
left: Dataset[T],
right: Dataset[U],
idColumns: Seq[String],
ignoreColumns: Seq[String] = Seq.empty
)(implicit diffEncoder: Encoder[V]): Dataset[V] =
default.diffAs(left, right, idColumns, ignoreColumns)
/**
* Returns a new Dataset that contains the differences between two Datasets of the same type `T`.
*
* See `of(Dataset[T], Dataset[T], String*)`.
*
* This requires an additional explicit `Encoder[V]` for the return type `Dataset[V]`.
*/
@scala.annotation.varargs
def ofAs[T, V](left: Dataset[T], right: Dataset[T], diffEncoder: Encoder[V], idColumns: String*): Dataset[V] =
default.diffAs(left, right, diffEncoder, idColumns: _*)
/**
* Returns a new Dataset that contains the differences between two Datasets of similar types `T` and `U`.
*
* See `of(Dataset[T], Dataset[U], Seq[String], Seq[String])`.
*
* This requires an additional explicit `Encoder[V]` for the return type `Dataset[V]`.
*/
def ofAs[T, U, V](
left: Dataset[T],
right: Dataset[U],
diffEncoder: Encoder[V],
idColumns: Seq[String],
ignoreColumns: Seq[String]
): Dataset[V] =
default.diffAs(left, right, diffEncoder, idColumns, ignoreColumns)
/**
* Returns a new Dataset that contains the differences between two Datasets of similar types `T` and `U`.
*
* See `of(Dataset[T], Dataset[U], Seq[String], Seq[String])`.
*
* This requires an additional explicit `Encoder[V]` for the return type `Dataset[V]`.
*/
def ofAs[T, U, V](
left: Dataset[T],
right: Dataset[U],
diffEncoder: Encoder[V],
idColumns: java.util.List[String],
ignoreColumns: java.util.List[String]
): Dataset[V] =
default.diffAs(left, right, diffEncoder, idColumns, ignoreColumns)
/**
* Returns a new Dataset that contains the differences between two Datasets of the same type `T` as tuples of type
* `(String, T, T)`.
*
* See `of(Dataset[T], Dataset[T], String*)`.
*/
@scala.annotation.varargs
def ofWith[T](left: Dataset[T], right: Dataset[T], idColumns: String*): Dataset[(String, T, T)] =
default.diffWith(left, right, idColumns: _*)
/**
* Returns a new Dataset that contains the differences between two Datasets of similar types `T` and `U` as tuples
* of type `(String, T, U)`.
*
* See `of(Dataset[T], Dataset[U], Seq[String], Seq[String])`.
*/
def ofWith[T, U](
left: Dataset[T],
right: Dataset[U],
idColumns: Seq[String],
ignoreColumns: Seq[String]
): Dataset[(String, T, U)] =
default.diffWith(left, right, idColumns, ignoreColumns)
/**
* Returns a new Dataset that contains the differences between two Datasets of similar types `T` and `U` as tuples
* of type `(String, T, U)`.
*
* See `of(Dataset[T], Dataset[U], Seq[String], Seq[String])`.
*/
def ofWith[T, U](
left: Dataset[T],
right: Dataset[U],
idColumns: java.util.List[String],
ignoreColumns: java.util.List[String]
): Dataset[(String, T, U)] =
default.diffWith(left, right, idColumns, ignoreColumns)
}
================================================
FILE: src/main/scala/uk/co/gresearch/spark/diff/DiffComparators.scala
================================================
/*
* Copyright 2022 G-Research
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package uk.co.gresearch.spark.diff
import org.apache.spark.sql.Encoder
import org.apache.spark.sql.types.DataType
import uk.co.gresearch.spark.diff.comparator._
import java.time.Duration
object DiffComparators {
/**
* The default comparator used in [[DiffOptions.default.defaultComparator]].
*/
def default(): DiffComparator = DefaultDiffComparator
/**
* A comparator equivalent to `Column <=> Column`. Null values are considered equal.
*/
def nullSafeEqual(): DiffComparator = NullSafeEqualDiffComparator
/**
* Return a comparator that uses the given [[math.Equiv]] to compare values of type [[T]]. The implicit [[Encoder]] of
* type [[T]] determines the input data type of the comparator. Only columns of that type can be compared.
*/
def equiv[T: Encoder](equiv: math.Equiv[T]): EquivDiffComparator[T] = EquivDiffComparator(equiv)
/**
* Return a comparator that uses the given [[math.Equiv]] to compare values of type [[T]]. Only columns of the given
* data type `inputType` can be compared.
*/
def equiv[T](equiv: math.Equiv[T], inputType: DataType): EquivDiffComparator[T] =
EquivDiffComparator(equiv, inputType)
/**
* Return a comparator that uses the given [[math.Equiv]] to compare values of any type.
*/
def equiv(equiv: math.Equiv[Any]): EquivDiffComparator[Any] = EquivDiffComparator(equiv)
/**
* This comparator considers values equal when they are less than `epsilon` apart. It can be configured to use
* `epsilon` as an absolute (`.asAbsolute()`) threshold, or as relative (`.asRelative()`) to the larger value.
* Further, the threshold itself can be considered equal (`.asInclusive()`) or not equal (`.asExclusive()`):
*
*   - `.asAbsolute().asInclusive()`: `abs(left - right) ≤ epsilon`
*   - `.asAbsolute().asExclusive()`: `abs(left - right) < epsilon`
*   - `.asRelative().asInclusive()`: `abs(left - right) ≤ epsilon * max(abs(left), abs(right))`
*   - `.asRelative().asExclusive()`: `abs(left - right) < epsilon * max(abs(left), abs(right))`
* Requires compared column types to implement `-`, `*`, `<`, `==`, and `abs`.
*/
def epsilon(epsilon: Double): EpsilonDiffComparator = EpsilonDiffComparator(epsilon)
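// Illustrative usage (not part of this file): register an epsilon comparator for all double
// columns, treating values within 1% of the larger absolute value as equal:
//   DiffOptions.default.withComparator(
//     DiffComparators.epsilon(0.01).asRelative().asInclusive(), org.apache.spark.sql.types.DoubleType)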
/**
* A comparator for string values.
*
* With `whitespaceAgnostic` set `true`, differences in whitespace are ignored, including leading and trailing
* whitespace. With `whitespaceAgnostic` set `false`, this is equivalent to the default string comparison (see
* [[default()]]).
*/
def string(whitespaceAgnostic: Boolean = true): StringDiffComparator =
if (whitespaceAgnostic) {
WhitespaceDiffComparator
} else {
StringDiffComparator
}
/**
* This comparator considers two `DateType` or `TimestampType` values equal when they are at most `duration` apart.
* Duration is an instance of `java.time.Duration`.
*
* The comparator can be configured to consider `duration` as equal (`.asInclusive()`) or not equal
* (`.asExclusive()`):
*   - `DiffComparators.duration(duration).asInclusive()`: `abs(left - right) ≤ duration`
*   - `DiffComparators.duration(duration).asExclusive()`: `abs(left - right) < duration`
*/
def duration(duration: Duration): DurationDiffComparator = DurationDiffComparator(duration)
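// Illustrative usage (not part of this file): consider two timestamps equal when they are at
// most five seconds apart; the column name "event_time" is a hypothetical example:
//   DiffOptions.default.withComparator(DiffComparators.duration(Duration.ofSeconds(5)), "event_time")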
/**
* This comparator compares two `Map[K,V]` values. They are equal when they match in all their keys and values.
*/
def map[K: Encoder, V: Encoder](): DiffComparator = MapDiffComparator[K, V](keyOrderSensitive = false)
/**
* This comparator compares two `Map[keyType,valueType]` values. They are equal when they match in all their keys and
* values.
*/
def map(keyType: DataType, valueType: DataType, keyOrderSensitive: Boolean = false): DiffComparator =
MapDiffComparator(keyType, valueType, keyOrderSensitive)
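// Illustrative usage (not part of this file): compare a map-typed column ignoring key order;
// the column name "attributes" is a hypothetical example:
//   DiffOptions.default.withComparator(
//     DiffComparators.map(org.apache.spark.sql.types.StringType, org.apache.spark.sql.types.IntegerType),
//     "attributes")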
// for backward compatibility with v2.4.0 up to v2.8.0
// replace with a default value in the map method above when moving to v3
/**
* This comparator compares two `Map[K,V]` values. They are equal when they match in all their keys and values.
*
* @param keyOrderSensitive
* comparator compares key order if true
*/
def map[K: Encoder, V: Encoder](keyOrderSensitive: Boolean): DiffComparator =
MapDiffComparator[K, V](keyOrderSensitive)
}
================================================
FILE: src/main/scala/uk/co/gresearch/spark/diff/DiffOptions.scala
================================================
/*
* Copyright 2020 G-Research
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package uk.co.gresearch.spark.diff
import org.apache.spark.sql.Encoder
import org.apache.spark.sql.types.{DataType, StructField}
import uk.co.gresearch.spark.diff
import uk.co.gresearch.spark.diff.DiffMode.{Default, DiffMode}
import uk.co.gresearch.spark.diff.comparator.{
DefaultDiffComparator,
DiffComparator,
EquivDiffComparator,
TypedDiffComparator
}
import scala.annotation.varargs
import scala.collection.Map
/**
* The diff mode determines the output columns of the diffing transformation.
*/
object DiffMode extends Enumeration {
type DiffMode = Value
/**
* The diff mode determines the output columns of the diffing transformation.
*
*   - ColumnByColumn: The diff contains value columns from the left and right dataset, arranged column by column:
*     diff, (changes,) id-1, id-2, …, left-value-1, right-value-1, left-value-2, right-value-2, …
*   - SideBySide: The diff contains value columns from the left and right dataset, arranged side by side:
*     diff, (changes,) id-1, id-2, …, left-value-1, left-value-2, …, right-value-1, right-value-2, …
*   - LeftSide / RightSide: The diff contains value columns from the left / right dataset only.
*/
val ColumnByColumn, SideBySide, LeftSide, RightSide = Value
/**
* The diff mode determines the output columns of the diffing transformation. The default diff mode is ColumnByColumn.
*
* Default is not an enum value here (hence the def) so that we do not have to include it in every match clause. We
* will see the respective enum value that Default points to instead.
*/
def Default: diff.DiffMode.Value = ColumnByColumn
// we want to return Default's enum value for 'Default' here but cannot override super.withName.
def withNameOption(name: String): Option[Value] = {
if ("Default".equals(name)) {
Some(DiffMode.Default)
} else {
try {
Some(super.withName(name))
} catch {
case _: NoSuchElementException => None
}
}
}
}
/**
* Configuration class for diffing Datasets.
*
* @param diffColumn
* name of the diff column
* @param leftColumnPrefix
* prefix of columns from the left Dataset
* @param rightColumnPrefix
* prefix of columns from the right Dataset
* @param insertDiffValue
* value in diff column for inserted rows
* @param changeDiffValue
* value in diff column for changed rows
* @param deleteDiffValue
* value in diff column for deleted rows
* @param nochangeDiffValue
* value in diff column for un-changed rows
* @param changeColumn
* name of change column
* @param diffMode
* diff output format
* @param sparseMode
*   when true, un-changed values are set to null on both sides
* @param defaultComparator
* default custom comparator
* @param dataTypeComparators
*   custom comparators for specific data types
* @param columnNameComparators
*   custom comparators for specific column names
*/
case class DiffOptions(
diffColumn: String,
leftColumnPrefix: String,
rightColumnPrefix: String,
insertDiffValue: String,
changeDiffValue: String,
deleteDiffValue: String,
nochangeDiffValue: String,
changeColumn: Option[String] = None,
diffMode: DiffMode = Default,
sparseMode: Boolean = false,
defaultComparator: DiffComparator = DefaultDiffComparator,
dataTypeComparators: Map[DataType, DiffComparator] = Map.empty,
columnNameComparators: Map[String, DiffComparator] = Map.empty
) {
// Constructor for Java to construct default options
def this() = this("diff", "left", "right", "I", "C", "D", "N")
def this(
diffColumn: String,
leftColumnPrefix: String,
rightColumnPrefix: String,
insertDiffValue: String,
changeDiffValue: String,
deleteDiffValue: String,
nochangeDiffValue: String,
changeColumn: Option[String],
diffMode: DiffMode,
sparseMode: Boolean
) = {
this(
diffColumn,
leftColumnPrefix,
rightColumnPrefix,
insertDiffValue,
changeDiffValue,
deleteDiffValue,
nochangeDiffValue,
changeColumn,
diffMode,
sparseMode,
DefaultDiffComparator,
Map.empty,
Map.empty
)
}
require(leftColumnPrefix.nonEmpty, "Left column prefix must not be empty")
require(rightColumnPrefix.nonEmpty, "Right column prefix must not be empty")
require(
handleConfiguredCaseSensitivity(leftColumnPrefix) != handleConfiguredCaseSensitivity(rightColumnPrefix),
s"Left and right column prefix must be distinct: $leftColumnPrefix"
)
val diffValues = Seq(insertDiffValue, changeDiffValue, deleteDiffValue, nochangeDiffValue)
require(diffValues.distinct.length == diffValues.length, s"Diff values must be distinct: $diffValues")
require(
!changeColumn.map(handleConfiguredCaseSensitivity).contains(handleConfiguredCaseSensitivity(diffColumn)),
s"Change column name must be different to diff column: $diffColumn"
)
/**
* Fluent method to change the diff column name. Returns a new immutable DiffOptions instance with the new diff column
* name.
* @param diffColumn
* new diff column name
* @return
* new immutable DiffOptions instance
*/
def withDiffColumn(diffColumn: String): DiffOptions = {
this.copy(diffColumn = diffColumn)
}
/**
* Fluent method to change the prefix of columns from the left Dataset. Returns a new immutable DiffOptions instance
* with the new column prefix.
* @param leftColumnPrefix
* new column prefix
* @return
* new immutable DiffOptions instance
*/
def withLeftColumnPrefix(leftColumnPrefix: String): DiffOptions = {
this.copy(leftColumnPrefix = leftColumnPrefix)
}
/**
* Fluent method to change the prefix of columns from the right Dataset. Returns a new immutable DiffOptions instance
* with the new column prefix.
* @param rightColumnPrefix
* new column prefix
* @return
* new immutable DiffOptions instance
*/
def withRightColumnPrefix(rightColumnPrefix: String): DiffOptions = {
this.copy(rightColumnPrefix = rightColumnPrefix)
}
/**
* Fluent method to change the value of inserted rows in the diff column. Returns a new immutable DiffOptions instance
* with the new diff value.
* @param insertDiffValue
* new diff value
* @return
* new immutable DiffOptions instance
*/
def withInsertDiffValue(insertDiffValue: String): DiffOptions = {
this.copy(insertDiffValue = insertDiffValue)
}
/**
* Fluent method to change the value of changed rows in the diff column. Returns a new immutable DiffOptions instance
* with the new diff value.
* @param changeDiffValue
* new diff value
* @return
* new immutable DiffOptions instance
*/
def withChangeDiffValue(changeDiffValue: String): DiffOptions = {
this.copy(changeDiffValue = changeDiffValue)
}
/**
* Fluent method to change the value of deleted rows in the diff column. Returns a new immutable DiffOptions instance
* with the new diff value.
* @param deleteDiffValue
* new diff value
* @return
* new immutable DiffOptions instance
*/
def withDeleteDiffValue(deleteDiffValue: String): DiffOptions = {
this.copy(deleteDiffValue = deleteDiffValue)
}
/**
* Fluent method to change the value of un-changed rows in the diff column. Returns a new immutable DiffOptions
* instance with the new diff value.
* @param nochangeDiffValue
* new diff value
* @return
* new immutable DiffOptions instance
*/
def withNochangeDiffValue(nochangeDiffValue: String): DiffOptions = {
this.copy(nochangeDiffValue = nochangeDiffValue)
}
/**
* Fluent method to change the change column name. Returns a new immutable DiffOptions instance with the new change
* column name.
* @param changeColumn
* new change column name
* @return
* new immutable DiffOptions instance
*/
def withChangeColumn(changeColumn: String): DiffOptions = {
this.copy(changeColumn = Some(changeColumn))
}
/**
* Fluent method to remove change column. Returns a new immutable DiffOptions instance without a change column.
* @return
* new immutable DiffOptions instance
*/
def withoutChangeColumn(): DiffOptions = {
this.copy(changeColumn = None)
}
/**
* Fluent method to change the diff mode. Returns a new immutable DiffOptions instance with the new diff mode.
* @return
* new immutable DiffOptions instance
*/
def withDiffMode(diffMode: DiffMode): DiffOptions = {
this.copy(diffMode = diffMode)
}
/**
* Fluent method to change the sparse mode. Returns a new immutable DiffOptions instance with the new sparse mode.
* @return
* new immutable DiffOptions instance
*/
def withSparseMode(sparseMode: Boolean): DiffOptions = {
this.copy(sparseMode = sparseMode)
}
/**
* Fluent method to add a default comparator. Returns a new immutable DiffOptions instance with the new default
* comparator.
* @return
* new immutable DiffOptions instance
*/
def withDefaultComparator(diffComparator: DiffComparator): DiffOptions = {
this.copy(defaultComparator = diffComparator)
}
/**
* Fluent method to add a typed equivalent operator as a default comparator. The encoder defines the input type of the
* comparator. Returns a new immutable DiffOptions instance with the new default comparator.
* @note
* The `math.Equiv` will not be given any null values. Null-aware comparators can only be implemented via the
* `DiffComparator` interface.
* @return
* new immutable DiffOptions instance
*/
def withDefaultComparator[T: Encoder](equiv: math.Equiv[T]): DiffOptions = {
withDefaultComparator(EquivDiffComparator(equiv))
}
/**
* Fluent method to add a typed equivalent operator as a default comparator. Returns a new immutable DiffOptions
* instance with the new default comparator.
* @note
* The `math.Equiv` will not be given any null values. Null-aware comparators can only be implemented via the
* `DiffComparator` interface.
* @return
* new immutable DiffOptions instance
*/
def withDefaultComparator[T](equiv: math.Equiv[T], inputDataType: DataType): DiffOptions = {
withDefaultComparator(EquivDiffComparator(equiv, inputDataType))
}
/**
* Fluent method to add an equivalent operator as a default comparator. Returns a new immutable DiffOptions instance
* with the new default comparator.
* @note
* The `math.Equiv` will not be given any null values. Null-aware comparators can only be implemented via the
* `DiffComparator` interface.
* @return
* new immutable DiffOptions instance
*/
def withDefaultComparator(equiv: math.Equiv[Any]): DiffOptions = {
withDefaultComparator(EquivDiffComparator(equiv))
}
/**
* Fluent method to add a comparator for its input data type. Returns a new immutable DiffOptions instance with the
* new comparator.
* @return
* new immutable DiffOptions instance
*/
def withComparator(diffComparator: TypedDiffComparator): DiffOptions = {
if (dataTypeComparators.contains(diffComparator.inputType)) {
throw new IllegalArgumentException(s"A comparator for data type ${diffComparator.inputType} exists already.")
}
this.copy(dataTypeComparators = dataTypeComparators ++ Map(diffComparator.inputType -> diffComparator))
}
/**
* Fluent method to add a comparator for one or more data types. Returns a new immutable DiffOptions instance with the
* new comparator.
* @return
* new immutable DiffOptions instance
*/
@varargs
def withComparator(diffComparator: DiffComparator, dataType: DataType, dataTypes: DataType*): DiffOptions = {
val allDataTypes = dataType +: dataTypes
diffComparator match {
case typed: TypedDiffComparator if allDataTypes.exists(_ != typed.inputType) =>
throw new IllegalArgumentException(
s"Comparator with input type ${typed.inputType.simpleString} " +
s"cannot be used for data type ${allDataTypes.filter(_ != typed.inputType).map(_.simpleString).sorted.mkString(", ")}"
)
case _ =>
}
val existingDataTypes = allDataTypes.filter(dataTypeComparators.contains)
if (existingDataTypes.nonEmpty) {
throw new IllegalArgumentException(
s"A comparator for data type${if (existingDataTypes.length > 1) "s" else ""} " +
s"${existingDataTypes.map(_.simpleString).sorted.mkString(", ")} exists already."
)
}
this.copy(dataTypeComparators = dataTypeComparators ++ allDataTypes.map(dt => dt -> diffComparator))
}
/**
* Fluent method to add a comparator for one or more column names. Returns a new immutable DiffOptions instance with
* the new comparator.
* @return
* new immutable DiffOptions instance
*/
@varargs
def withComparator(diffComparator: DiffComparator, columnName: String, columnNames: String*): DiffOptions = {
val allColumnNames = columnName +: columnNames
val existingColumnNames = allColumnNames.filter(columnNameComparators.contains)
if (existingColumnNames.nonEmpty) {
throw new IllegalArgumentException(
s"A comparator for column name${if (existingColumnNames.length > 1) "s" else ""} " +
s"${existingColumnNames.sorted.mkString(", ")} exists already."
)
}
this.copy(columnNameComparators = columnNameComparators ++ allColumnNames.map(name => name -> diffComparator))
}
/**
* Fluent method to add a typed equivalent operator as a comparator for its input data type. The encoder defines the
* input type of the comparator. Returns a new immutable DiffOptions instance with the new comparator.
* @note
* The `math.Equiv` will not be given any null values. Null-aware comparators can only be implemented via the
* `DiffComparator` interface.
* @return
* new immutable DiffOptions instance
*/
def withComparator[T: Encoder](equiv: math.Equiv[T]): DiffOptions =
withComparator(EquivDiffComparator(equiv))
/**
* Fluent method to add a typed equivalent operator as a comparator for one or more column names. The encoder defines
* the input type of the comparator. Returns a new immutable DiffOptions instance with the new comparator.
* @note
* The `math.Equiv` will not be given any null values. Null-aware comparators can only be implemented via the
* `DiffComparator` interface.
* @return
* new immutable DiffOptions instance
*/
def withComparator[T: Encoder](equiv: math.Equiv[T], columnName: String, columnNames: String*): DiffOptions =
withComparator(EquivDiffComparator(equiv), columnName, columnNames: _*)
/**
* Fluent method to add an equivalent operator as a comparator for one or more column names. Returns a new immutable
* DiffOptions instance with the new comparator.
* @note
* The `math.Equiv` will not be given any null values. Null-aware comparators can only be implemented via the
* `DiffComparator` interface.
* @note
* Java-specific method
* @return
* new immutable DiffOptions instance
*/
@varargs
def withComparator[T](
equiv: math.Equiv[T],
encoder: Encoder[T],
columnName: String,
columnNames: String*
): DiffOptions =
withComparator(EquivDiffComparator(equiv)(encoder), columnName, columnNames: _*)
/**
* Fluent method to add an equivalent operator as a comparator for one or more data types. Returns a new immutable
* DiffOptions instance with the new comparator.
* @note
* The `math.Equiv` will not be given any null values. Null-aware comparators can only be implemented via the
* `DiffComparator` interface.
* @return
* new immutable DiffOptions instance
*/
// There is probably no use case for calling this with multiple data types while T is not Any.
// But this is the only way to define withComparator[T](equiv: math.Equiv[T], dataType: DataType)
// without being ambiguous with withComparator(equiv: math.Equiv[Any], dataType: DataType, dataTypes: DataType*)
@varargs
def withComparator[T](equiv: math.Equiv[T], dataType: DataType, dataTypes: DataType*): DiffOptions =
(dataType +: dataTypes).foldLeft(this)((options, dataType) =>
options.withComparator(EquivDiffComparator(equiv, dataType))
)
/**
* Fluent method to add an equivalent operator as a comparator for one or more column names. Returns a new immutable
* DiffOptions instance with the new comparator.
* @note
* The `math.Equiv` will not be given any null values. Null-aware comparators can only be implemented via the
* `DiffComparator` interface.
* @return
* new immutable DiffOptions instance
*/
@varargs
def withComparator(equiv: math.Equiv[Any], columnName: String, columnNames: String*): DiffOptions =
withComparator(EquivDiffComparator(equiv), columnName, columnNames: _*)
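// comparator resolution order: a comparator registered for the column name wins over one
// registered for the column's data type, which wins over the default comparator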
private[diff] def comparatorFor(column: StructField): DiffComparator =
columnNameComparators
.get(column.name)
.orElse(dataTypeComparators.get(column.dataType))
.getOrElse(defaultComparator)
}
object DiffOptions {
/**
* Default diffing options.
*/
val default: DiffOptions = new DiffOptions()
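// Illustrative usage (not part of this file): derive customised options fluently from the defaults:
//   DiffOptions.default.withChangeColumn("changes").withSparseMode(true)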
}
================================================
FILE: src/main/scala/uk/co/gresearch/spark/diff/comparator/DefaultDiffComparator.scala
================================================
/*
* Copyright 2022 G-Research
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package uk.co.gresearch.spark.diff.comparator
import org.apache.spark.sql.Column
case object DefaultDiffComparator extends DiffComparator {
override def equiv(left: Column, right: Column): Column = NullSafeEqualDiffComparator.equiv(left, right)
}
================================================
FILE: src/main/scala/uk/co/gresearch/spark/diff/comparator/DiffComparator.scala
================================================
/*
* Copyright 2022 G-Research
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package uk.co.gresearch.spark.diff.comparator
import org.apache.spark.sql.Column
trait DiffComparator {
def equiv(left: Column, right: Column): Column
}
================================================
FILE: src/main/scala/uk/co/gresearch/spark/diff/comparator/DurationDiffComparator.scala
================================================
/*
* Copyright 2022 G-Research
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package uk.co.gresearch.spark.diff.comparator
import org.apache.spark.sql.Column
import org.apache.spark.sql.functions.abs
import uk.co.gresearch.spark
import uk.co.gresearch.spark.SparkVersion
import uk.co.gresearch.spark.diff.comparator.DurationDiffComparator.isNotSupportedBySpark
import java.time.Duration
/**
* Compares two timestamps and considers them equal when they are less than (or equal to when inclusive = true) a given
* duration apart.
*
* @param duration
* equality threshold
* @param inclusive
* duration is considered equal when true
*/
case class DurationDiffComparator(duration: Duration, inclusive: Boolean = true) extends DiffComparator {
if (isNotSupportedBySpark) {
throw new UnsupportedOperationException(
s"java.time.Duration is not supported by Spark ${spark.SparkCompatVersionString}"
)
}
override def equiv(left: Column, right: Column): Column = {
val inDuration =
if (inclusive)
(diff: Column) => diff <= duration
else
(diff: Column) => diff < duration
left.isNull && right.isNull ||
left.isNotNull && right.isNotNull && inDuration(abs(left - right))
}
def asInclusive(): DurationDiffComparator = if (inclusive) this else copy(inclusive = true)
def asExclusive(): DurationDiffComparator = if (inclusive) copy(inclusive = false) else this
}
object DurationDiffComparator extends SparkVersion {
val isSupportedBySpark: Boolean = SparkMajorVersion == 3 && SparkMinorVersion >= 3 || SparkMajorVersion > 3
val isNotSupportedBySpark: Boolean = !isSupportedBySpark
}
================================================
FILE: src/main/scala/uk/co/gresearch/spark/diff/comparator/EpsilonDiffComparator.scala
================================================
/*
* Copyright 2022 G-Research
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package uk.co.gresearch.spark.diff.comparator
import org.apache.spark.sql.Column
import org.apache.spark.sql.functions.{abs, greatest}
case class EpsilonDiffComparator(epsilon: Double, relative: Boolean = true, inclusive: Boolean = true)
extends DiffComparator {
override def equiv(left: Column, right: Column): Column = {
val threshold =
if (relative)
greatest(abs(left), abs(right)) * epsilon
else
epsilon
val inEpsilon =
if (inclusive)
(diff: Column) => diff <= threshold
else
(diff: Column) => diff < threshold
left.isNull && right.isNull || left.isNotNull && right.isNotNull && inEpsilon(abs(left - right))
}
def asAbsolute(): EpsilonDiffComparator = if (relative) copy(relative = false) else this
def asRelative(): EpsilonDiffComparator = if (relative) this else copy(relative = true)
def asInclusive(): EpsilonDiffComparator = if (inclusive) this else copy(inclusive = true)
def asExclusive(): EpsilonDiffComparator = if (inclusive) copy(inclusive = false) else this
}
================================================
FILE: src/main/scala/uk/co/gresearch/spark/diff/comparator/EquivDiffComparator.scala
================================================
/*
* Copyright 2022 G-Research
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package uk.co.gresearch.spark.diff.comparator
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.encoders.encoderFor
import org.apache.spark.sql.catalyst.expressions.codegen.Block.BlockHelper
import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode, FalseLiteral}
import org.apache.spark.sql.catalyst.expressions.{BinaryExpression, BinaryOperator, Expression}
import org.apache.spark.sql.extension.{ColumnExtension, ExpressionExtension}
import org.apache.spark.sql.types.{BooleanType, DataType}
import org.apache.spark.sql.{Column, Encoder}
trait EquivDiffComparator[T] extends DiffComparator {
val equiv: math.Equiv[T]
}
private trait ExpressionEquivDiffComparator[T] extends EquivDiffComparator[T] {
def equiv(left: Expression, right: Expression): EquivExpression[T]
def equiv(left: Column, right: Column): Column = equiv(left.expr, right.expr).column
}
trait TypedEquivDiffComparator[T] extends EquivDiffComparator[T] with TypedDiffComparator
private[comparator] trait TypedEquivDiffComparatorWithInput[T]
extends ExpressionEquivDiffComparator[T]
with TypedEquivDiffComparator[T] {
def equiv(left: Expression, right: Expression): Equiv[T] = Equiv(left, right, equiv, inputType)
}
private[comparator] case class InputTypedEquivDiffComparator[T](equiv: math.Equiv[T], inputType: DataType)
extends TypedEquivDiffComparatorWithInput[T]
object EquivDiffComparator {
def apply[T: Encoder](equiv: math.Equiv[T]): TypedEquivDiffComparator[T] = EncoderEquivDiffComparator(equiv)
def apply[T](equiv: math.Equiv[T], inputType: DataType): TypedEquivDiffComparator[T] =
InputTypedEquivDiffComparator(equiv, inputType)
def apply(equiv: math.Equiv[Any]): EquivDiffComparator[Any] = EquivAnyDiffComparator(equiv)
private case class EncoderEquivDiffComparator[T: Encoder](equiv: math.Equiv[T])
extends ExpressionEquivDiffComparator[T]
with TypedEquivDiffComparator[T] {
override def inputType: DataType = encoderFor[T].schema.fields(0).dataType
def equiv(left: Expression, right: Expression): Equiv[T] = Equiv(left, right, equiv, inputType)
}
private case class EquivAnyDiffComparator(equiv: math.Equiv[Any]) extends ExpressionEquivDiffComparator[Any] {
def equiv(left: Expression, right: Expression): EquivExpression[Any] = EquivAny(left, right, equiv)
}
}
private trait EquivExpression[T] extends BinaryExpression {
val equiv: math.Equiv[T]
override def nullable: Boolean = false
override def dataType: DataType = BooleanType
override def eval(input: InternalRow): Any = {
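// null-safe semantics: two nulls are equivalent, a single null is not; only non-null values
// are passed to the user-supplied math.Equiv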
val input1 = left.eval(input).asInstanceOf[T]
val input2 = right.eval(input).asInstanceOf[T]
if (input1 == null && input2 == null) {
true
} else if (input1 == null || input2 == null) {
false
} else {
equiv.equiv(input1, input2)
}
}
override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
val eval1 = left.genCode(ctx)
val eval2 = right.genCode(ctx)
val equivRef = ctx.addReferenceObj("equiv", equiv, math.Equiv.getClass.getName.stripSuffix("$"))
ev.copy(
code = eval1.code + eval2.code + code"""
boolean ${ev.value} = (${eval1.isNull} && ${eval2.isNull}) ||
(!${eval1.isNull} && !${eval2.isNull} && $equivRef.equiv(${eval1.value}, ${eval2.value}));""",
isNull = FalseLiteral
)
}
}
private trait EquivOperator[T] extends BinaryOperator with EquivExpression[T] {
val equivInputType: DataType
override def inputType: DataType = equivInputType
override def symbol: String = "≡"
}
private case class Equiv[T](left: Expression, right: Expression, equiv: math.Equiv[T], equivInputType: DataType)
extends EquivOperator[T] {
override protected def withNewChildrenInternal(newLeft: Expression, newRight: Expression): Equiv[T] =
copy(left = newLeft, right = newRight)
}
private case class EquivAny(left: Expression, right: Expression, equiv: math.Equiv[Any]) extends EquivExpression[Any] {
override protected def withNewChildrenInternal(newLeft: Expression, newRight: Expression): EquivAny =
copy(left = newLeft, right = newRight)
}
================================================
FILE: src/main/scala/uk/co/gresearch/spark/diff/comparator/MapDiffComparator.scala
================================================
/*
* Copyright 2022 G-Research
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package uk.co.gresearch.spark.diff.comparator
import org.apache.spark.sql.catalyst.encoders.encoderFor
import org.apache.spark.sql.catalyst.expressions.UnsafeMapData
import org.apache.spark.sql.types.{DataType, MapType}
import org.apache.spark.sql.{Column, Encoder}
import scala.reflect.ClassTag
case class MapDiffComparator[K, V](private val comparator: EquivDiffComparator[UnsafeMapData]) extends DiffComparator {
override def equiv(left: Column, right: Column): Column = comparator.equiv(left, right)
}
private case class MapDiffEquiv[K: ClassTag, V](keyType: DataType, valueType: DataType, keyOrderSensitive: Boolean)
extends math.Equiv[UnsafeMapData] {
override def equiv(left: UnsafeMapData, right: UnsafeMapData): Boolean = {
val leftKeys: Array[K] = left.keyArray().toArray(keyType)
val rightKeys: Array[K] = right.keyArray().toArray(keyType)
val leftKeysIndices: Map[K, Int] = leftKeys.zipWithIndex.toMap
val rightKeysIndices: Map[K, Int] = rightKeys.zipWithIndex.toMap
val leftValues = left.valueArray()
val rightValues = right.valueArray()
// can only be evaluated when right has same keys as left
lazy val valuesAreEqual = leftKeysIndices
.map { case (key, index) => index -> rightKeysIndices(key) }
.map { case (leftIndex, rightIndex) =>
(leftIndex, rightIndex, leftValues.isNullAt(leftIndex), rightValues.isNullAt(rightIndex))
}
.map { case (leftIndex, rightIndex, leftIsNull, rightIsNull) =>
leftIsNull && rightIsNull ||
!leftIsNull && !rightIsNull && leftValues
.get(leftIndex, valueType)
.equals(rightValues.get(rightIndex, valueType))
}
left.numElements() == right.numElements() &&
(keyOrderSensitive && leftKeys
.sameElements(rightKeys) || !keyOrderSensitive && leftKeys.toSet.diff(rightKeys.toSet).isEmpty) &&
valuesAreEqual.forall(identity)
}
}
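// For illustration of the equivalence above: with keyOrderSensitive = false,
// Map(1 -> "a", 2 -> "b") and Map(2 -> "b", 1 -> "a") are considered equal
// (key sets and per-key values match); with keyOrderSensitive = true, the
// physical key order must match as well (leftKeys.sameElements(rightKeys)).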
case object MapDiffComparator {
def apply[K: Encoder, V: Encoder](keyOrderSensitive: Boolean): MapDiffComparator[K, V] = {
val keyType = encoderFor[K].schema.fields(0).dataType
val valueType = encoderFor[V].schema.fields(0).dataType
val equiv = MapDiffEquiv(keyType, valueType, keyOrderSensitive)
val dataType = MapType(keyType, valueType)
val comparator = InputTypedEquivDiffComparator[UnsafeMapData](equiv, dataType)
MapDiffComparator[K, V](comparator)
}
def apply(keyType: DataType, valueType: DataType, keyOrderSensitive: Boolean): MapDiffComparator[Any, Any] = {
val equiv = MapDiffEquiv(keyType, valueType, keyOrderSensitive)
val dataType = MapType(keyType, valueType)
val comparator = InputTypedEquivDiffComparator[UnsafeMapData](equiv, dataType)
MapDiffComparator[Any, Any](comparator)
}
// for backward compatibility to v2.4.0 up to v2.8.0
// replace with default value in above apply when moving to v3
def apply[K: Encoder, V: Encoder](): MapDiffComparator[K, V] = apply(keyOrderSensitive = false)
}
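// Usage sketch (assuming implicit encoders for the key and value types, e.g.
// via spark.implicits._): a comparator for map<int, string> columns that
// ignores the order of keys.
//
//   import spark.implicits._
//   val mapComparator = MapDiffComparator[Int, String](keyOrderSensitive = false)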
================================================
FILE: src/main/scala/uk/co/gresearch/spark/diff/comparator/NullSafeEqualDiffComparator.scala
================================================
/*
* Copyright 2022 G-Research
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package uk.co.gresearch.spark.diff.comparator
import org.apache.spark.sql.Column
case object NullSafeEqualDiffComparator extends DiffComparator {
override def equiv(left: Column, right: Column): Column = left <=> right
}
================================================
FILE: src/main/scala/uk/co/gresearch/spark/diff/comparator/TypedDiffComparator.scala
================================================
/*
* Copyright 2022 G-Research
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package uk.co.gresearch.spark.diff.comparator
import org.apache.spark.sql.Column
import org.apache.spark.sql.types.{DataType, StringType}
trait TypedDiffComparator extends DiffComparator {
def inputType: DataType
}
trait StringDiffComparator extends TypedDiffComparator {
override def inputType: DataType = StringType
}
case object StringDiffComparator extends StringDiffComparator {
override def equiv(left: Column, right: Column): Column = DefaultDiffComparator.equiv(left, right)
}
================================================
FILE: src/main/scala/uk/co/gresearch/spark/diff/comparator/WhitespaceDiffComparator.scala
================================================
/*
* Copyright 2023 G-Research
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package uk.co.gresearch.spark.diff.comparator
import org.apache.spark.unsafe.types.UTF8String
case object WhitespaceDiffComparator extends TypedEquivDiffComparatorWithInput[UTF8String] with StringDiffComparator {
override val equiv: scala.Equiv[UTF8String] = (x: UTF8String, y: UTF8String) =>
x.trimAll()
.toString
.replaceAll("\\s+", " ")
.equals(
y.trimAll().toString.replaceAll("\\s+", " ")
)
}
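// For illustration: the comparator above considers two strings equal when they
// differ only in leading/trailing whitespace or in the length of internal
// whitespace runs, e.g. "  a \t b " is equivalent to "a b", but not to "ab".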
================================================
FILE: src/main/scala/uk/co/gresearch/spark/diff/package.scala
================================================
/*
* Copyright 2020 G-Research
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package uk.co.gresearch.spark
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.{DataFrame, Dataset, Encoder}
import java.util.Locale
package object diff {
implicit class DatasetDiff[T](ds: Dataset[T]) {
/**
* Returns a new DataFrame that contains the differences between this and the other Dataset of the same type `T`.
* Both Datasets must contain the same set of column names and data types. The order of columns in the two Datasets
* is not important, as each column is compared to the column with the same name in the other Dataset, not to the
* column at the same position.
*
* Optional id columns are used to uniquely identify rows to compare. If values in any non-id column differ
* between this and the other Dataset, that row is marked as `"C"`hange, and as `"N"`o-change otherwise. Rows of
* the other Dataset that do not exist in this Dataset (w.r.t. the values in the id columns) are marked as
* `"I"`nsert, and rows of this Dataset that do not exist in the other Dataset are marked as `"D"`elete.
*
* If no id columns are given (empty sequence), all columns are considered id columns. Then, no `"C"`hange rows will
* appear, as all changes will exist as respective `"D"`elete and `"I"`nsert rows.
*
* The returned DataFrame has the `diff` column as the first column. This holds the `"N"`, `"C"`, `"I"` or `"D"`
* strings. The id columns follow, then the non-id columns (all remaining columns).
*
* {{{
* val df1 = Seq((1, "one"), (2, "two"), (3, "three")).toDF("id", "value")
* val df2 = Seq((1, "one"), (2, "Two"), (4, "four")).toDF("id", "value")
*
* df1.diff(df2).show()
*
* // output:
* // +----+---+-----+
* // |diff| id|value|
* // +----+---+-----+
* // | N| 1| one|
* // | D| 2| two|
* // | I| 2| Two|
* // | D| 3|three|
* // | I| 4| four|
* // +----+---+-----+
*
* df1.diff(df2, "id").show()
*
* // output:
* // +----+---+----------+-----------+
* // |diff| id|left_value|right_value|
* // +----+---+----------+-----------+
* // | N| 1| one| one|
* // | C| 2| two| Two|
* // | D| 3| three| null|
* // | I| 4| null| four|
* // +----+---+----------+-----------+
*
* }}}
*
* The id columns are in order as given to the method. If no id columns are given then all columns of this Dataset
* are id columns and appear in the same order. The remaining non-id columns are in the order of this Dataset.
*
* The id column names are taken literally, i.e. "a.field" is interpreted as "`a.field`", which is a column name
* containing a dot. It is not interpreted as a column "a" with a field "field" (struct).
*/
// no @scala.annotation.varargs here as this implicit class is not nicely accessible from Java
def diff(other: Dataset[T], idColumns: String*): DataFrame = {
Diff.of(this.ds, other, idColumns: _*)
}
/**
* Returns a new DataFrame that contains the differences between two Datasets of similar types `T` and `U`. Both
* Datasets must contain the same set of column names and data types, except for the columns in `ignoreColumns`. The
* order of columns in the two Datasets is not relevant, as columns are compared based on the name, not the
* position.
*
* Optional id columns are used to uniquely identify rows to compare. If values in any non-id column differ
* between this and the other Dataset, that row is marked as `"C"`hange, and as `"N"`o-change otherwise. Rows of
* the other Dataset that do not exist in this Dataset (w.r.t. the values in the id columns) are marked as
* `"I"`nsert, and rows of this Dataset that do not exist in the other Dataset are marked as `"D"`elete.
*
* If no id columns are given (empty sequence), all columns are considered id columns. Then, no `"C"`hange rows will
* appear, as all changes will exist as respective `"D"`elete and `"I"`nsert rows.
*
* Values in optional ignore columns are not compared but included in the output DataFrame.
*
* The returned DataFrame has the `diff` column as the first column. This holds the `"N"`, `"C"`, `"I"` or `"D"`
* strings. The id columns follow, then the non-id columns (all remaining columns).
*
* {{{
* val df1 = Seq((1, "one"), (2, "two"), (3, "three")).toDF("id", "value")
* val df2 = Seq((1, "one"), (2, "Two"), (4, "four")).toDF("id", "value")
*
* df1.diff(df2).show()
*
* // output:
* // +----+---+-----+
* // |diff| id|value|
* // +----+---+-----+
* // | N| 1| one|
* // | D| 2| two|
* // | I| 2| Two|
* // | D| 3|three|
* // | I| 4| four|
* // +----+---+-----+
*
* df1.diff(df2, "id").show()
*
* // output:
* // +----+---+----------+-----------+
* // |diff| id|left_value|right_value|
* // +----+---+----------+-----------+
* // | N| 1| one| one|
* // | C| 2| two| Two|
* // | D| 3| three| null|
* // | I| 4| null| four|
* // +----+---+----------+-----------+
*
* }}}
*
* The id columns are in order as given to the method. If no id columns are given then all columns of this Dataset
* are id columns and appear in the same order. The remaining non-id columns are in the order of this Dataset.
*
* The id column names are taken literally, i.e. "a.field" is interpreted as "`a.field`", which is a column name
* containing a dot. It is not interpreted as a column "a" with a field "field" (struct).
*/
def diff[U](other: Dataset[U], idColumns: Seq[String], ignoreColumns: Seq[String]): DataFrame = {
Diff.of(this.ds, other, idColumns, ignoreColumns)
}
/**
* Returns a new DataFrame that contains the differences between this and the other Dataset of the same type `T`.
*
* See `diff(Dataset[T], String*)`.
*
* The schema of the returned DataFrame can be configured by the given `DiffOptions`.
*/
// no @scala.annotation.varargs here as this implicit class is not nicely accessible from Java
def diff(other: Dataset[T], options: DiffOptions, idColumns: String*): DataFrame = {
new Differ(options).diff(this.ds, other, idColumns: _*)
}
/**
* Returns a new DataFrame that contains the differences between this and the other Dataset of similar types `T` and
* `U`.
*
* See `diff(Dataset[U], Seq[String], Seq[String])`.
*
* The schema of the returned DataFrame can be configured by the given `DiffOptions`.
*/
def diff[U](
other: Dataset[U],
options: DiffOptions,
idColumns: Seq[String],
ignoreColumns: Seq[String]
): DataFrame = {
new Differ(options).diff(this.ds, other, idColumns, ignoreColumns)
}
/**
* Returns a new Dataset that contains the differences between this and the other Dataset of the same type `T`.
*
* See `diff(Dataset[T], String*)`.
*
* This requires an additional implicit `Encoder[V]` for the return type `Dataset[V]`.
*/
// no @scala.annotation.varargs here as this implicit class is not nicely accessible from Java
def diffAs[V](other: Dataset[T], idColumns: String*)(implicit diffEncoder: Encoder[V]): Dataset[V] = {
Diff.ofAs(this.ds, other, idColumns: _*)
}
/**
* Returns a new Dataset that contains the differences between this and the other Dataset of similar types `T` and
* `U`.
*
* See `diff(Dataset[U], Seq[String], Seq[String])`.
*
* This requires an additional implicit `Encoder[V]` for the return type `Dataset[V]`.
*/
def diffAs[U, V](other: Dataset[U], idColumns: Seq[String], ignoreColumns: Seq[String])(implicit
diffEncoder: Encoder[V]
): Dataset[V] = {
Diff.ofAs(this.ds, other, idColumns, ignoreColumns)
}
/**
* Returns a new Dataset that contains the differences between this and the other Dataset of the same type `T`.
*
* See `diff(Dataset[T], String*)`.
*
* This requires an additional implicit `Encoder[V]` for the return type `Dataset[V]`. The schema of the returned
* Dataset can be configured by the given `DiffOptions`.
*/
// no @scala.annotation.varargs here as this implicit class is not nicely accessible from Java
def diffAs[V](other: Dataset[T], options: DiffOptions, idColumns: String*)(implicit
diffEncoder: Encoder[V]
): Dataset[V] = {
new Differ(options).diffAs(this.ds, other, idColumns: _*)
}
/**
* Returns a new Dataset that contains the differences between this and the other Dataset of similar types `T` and
* `U`.
*
* See `diff(Dataset[U], Seq[String], Seq[String])`.
*
* This requires an additional implicit `Encoder[V]` for the return type `Dataset[V]`. The schema of the returned
* Dataset can be configured by the given `DiffOptions`.
*/
def diffAs[U, V](other: Dataset[T], options: DiffOptions, idColumns: Seq[String], ignoreColumns: Seq[String])(
implicit diffEncoder: Encoder[V]
): Dataset[V] = {
new Differ(options).diffAs(this.ds, other, idColumns, ignoreColumns)
}
/**
* Returns a new Dataset that contains the differences between this and the other Dataset of the same type `T`.
*
* See `diff(Dataset[T], String*)`.
*
* This requires an additional explicit `Encoder[V]` for the return type `Dataset[V]`.
*/
// no @scala.annotation.varargs here as this implicit class is not nicely accessible from Java
def diffAs[V](other: Dataset[T], diffEncoder: Encoder[V], idColumns: String*): Dataset[V] = {
Diff.ofAs(this.ds, other, diffEncoder, idColumns: _*)
}
/**
* Returns a new Dataset that contains the differences between this and the other Dataset of similar types `T` and
* `U`.
*
* See `diff(Dataset[U], Seq[String], Seq[String])`.
*
* This requires an additional explicit `Encoder[V]` for the return type `Dataset[V]`.
*/
def diffAs[U, V](
other: Dataset[U],
diffEncoder: Encoder[V],
idColumns: Seq[String],
ignoreColumns: Seq[String]
): Dataset[V] = {
Diff.ofAs(this.ds, other, diffEncoder, idColumns, ignoreColumns)
}
/**
* Returns a new Dataset that contains the differences between this and the other Dataset of the same type `T`.
*
* See `diff(Dataset[T], String*)`.
*
* This requires an additional explicit `Encoder[V]` for the return type `Dataset[V]`. The schema of the returned
* Dataset can be configured by the given `DiffOptions`.
*/
// no @scala.annotation.varargs here as this implicit class is not nicely accessible from Java
def diffAs[V](other: Dataset[T], options: DiffOptions, diffEncoder: Encoder[V], idColumns: String*): Dataset[V] = {
new Differ(options).diffAs(this.ds, other, diffEncoder, idColumns: _*)
}
/**
* Returns a new Dataset that contains the differences between this and the other Dataset of similar types `T` and
* `U`.
*
* See `diff(Dataset[U], Seq[String], Seq[String])`.
*
* This requires an additional explicit `Encoder[V]` for the return type `Dataset[V]`. The schema of the returned
* Dataset can be configured by the given `DiffOptions`.
*/
def diffAs[U, V](
other: Dataset[U],
options: DiffOptions,
diffEncoder: Encoder[V],
idColumns: Seq[String],
ignoreColumns: Seq[String]
): Dataset[V] = {
new Differ(options).diffAs(this.ds, other, diffEncoder, idColumns, ignoreColumns)
}
/**
* Returns a new Dataset that contains the differences between this and the other Dataset of the same type `T` as
* tuples of type `(String, T, T)`.
*
* See `diff(Dataset[T], String*)`.
*/
def diffWith(other: Dataset[T], idColumns: String*): Dataset[(String, T, T)] =
Diff.default.diffWith(this.ds, other, idColumns: _*)
/**
* Returns a new Dataset that contains the differences between this and the other Dataset of similar types `T` and
* `U` as tuples of type `(String, T, U)`.
*
* See `diff(Dataset[U], Seq[String], Seq[String])`.
*/
def diffWith[U](other: Dataset[U], idColumns: Seq[String], ignoreColumns: Seq[String]): Dataset[(String, T, U)] =
Diff.default.diffWith(this.ds, other, idColumns, ignoreColumns)
/**
* Returns a new Dataset that contains the differences between this and the other Dataset of the same type `T` as
* tuples of type `(String, T, T)`.
*
* See `diff(Dataset[T], String*)`.
*
* The schema of the returned Dataset can be configured by the given `DiffOptions`.
*/
def diffWith(other: Dataset[T], options: DiffOptions, idColumns: String*): Dataset[(String, T, T)] = {
new Differ(options).diffWith(this.ds, other, idColumns: _*)
}
/**
* Returns a new Dataset that contains the differences between this and the other Dataset of similar types `T` and
* `U` as tuples of type `(String, T, U)`.
*
* See `diff(Dataset[U], Seq[String], Seq[String])`.
*
* The schema of the returned Dataset can be configured by the given `DiffOptions`.
*/
def diffWith[U](
other: Dataset[U],
options: DiffOptions,
idColumns: Seq[String],
ignoreColumns: Seq[String]
): Dataset[(String, T, U)] = {
new Differ(options).diffWith(this.ds, other, idColumns, ignoreColumns)
}
}
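// Usage sketch for the typed variants (assuming spark.implicits._ and
// import uk.co.gresearch.spark.diff._; values are illustrative):
//
//   val left = Seq((1, "one"), (2, "two")).toDS
//   val right = Seq((1, "one"), (2, "Two")).toDS
//   // rows of (diff, left value, right value), e.g. ("C", (2, "two"), (2, "Two"))
//   left.diffWith(right, "_1").show()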
/**
* Produces a column name that considers configured case-sensitivity of column names. When case sensitivity is
* deactivated, it lower-cases the given column name; otherwise it returns the name unchanged.
*
* @param columnName
* column name
* @return
* case sensitive or insensitive column name
*/
private[diff] def handleConfiguredCaseSensitivity(columnName: String): String =
if (SQLConf.get.caseSensitiveAnalysis) columnName else columnName.toLowerCase(Locale.ROOT)
implicit class CaseInsensitiveSeq(seq: Seq[String]) {
def containsCaseSensitivity(string: String): Boolean =
seq.map(handleConfiguredCaseSensitivity).contains(handleConfiguredCaseSensitivity(string))
def filterIsInCaseSensitivity(other: Iterable[String]): Seq[String] = {
val otherSet = other.map(handleConfiguredCaseSensitivity).toSet
seq.filter(v => otherSet.contains(handleConfiguredCaseSensitivity(v)))
}
def diffCaseSensitivity(other: Iterable[String]): Seq[String] = {
val otherSet = other.map(handleConfiguredCaseSensitivity).toSet
seq.filter(v => !otherSet.contains(handleConfiguredCaseSensitivity(v)))
}
}
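// Example behaviour (assuming the default spark.sql.caseSensitive=false):
//   Seq("Id", "Value").containsCaseSensitivity("id")   // true
//   Seq("Id", "Value").diffCaseSensitivity(Seq("ID"))  // Seq("Value")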
implicit class CaseInsensitiveArray(array: Array[String]) {
def containsCaseSensitivity(string: String): Boolean =
array.map(handleConfiguredCaseSensitivity).contains(handleConfiguredCaseSensitivity(string))
def filterIsInCaseSensitivity(other: Iterable[String]): Array[String] =
array.toSeq.filterIsInCaseSensitivity(other).toArray
def diffCaseSensitivity(other: Iterable[String]): Array[String] = array.toSeq.diffCaseSensitivity(other).toArray
}
}
================================================
FILE: src/main/scala/uk/co/gresearch/spark/group/package.scala
================================================
/*
* Copyright 2022 G-Research
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package uk.co.gresearch.spark
import org.apache.spark.sql.functions.{col, struct}
import org.apache.spark.sql.{Column, Dataset, Encoder, Encoders}
import uk.co.gresearch.ExtendedAny
package object group {
/**
* This is a Dataset of key-value tuples that provides a flatMap function over the individual groups, while
* providing a sorted iterator over the group values.
*
* The key-value Dataset given to the constructor has to be partitioned by the key and sorted within partitions by
* the key and value.
*
* @param ds
* the properly partitioned and sorted dataset
* @tparam K
* type of the keys with ordering and encoder
* @tparam V
* type of the values with encoder
*/
case class SortedGroupByDataset[K: Ordering: Encoder, V: Encoder] private (ds: Dataset[(K, V)]) {
/**
* (Scala-specific) Applies the given function to each group of data. For each unique group, the function will be
* passed the group key and a sorted iterator that contains all of the elements in the group. The function can
* return an iterator containing elements of an arbitrary type which will be returned as a new [[Dataset]].
*
* This function does not support partial aggregation, and as a result requires shuffling all the data in the
* [[Dataset]]. If an application intends to perform an aggregation over each key, it is best to use the reduce
* function or an `org.apache.spark.sql.expressions#Aggregator`.
*
* Internally, the implementation will spill to disk if any given group is too large to fit into memory. However,
* users must take care to avoid materializing the whole iterator for a group (for example, by calling `toList`)
* unless they are sure that this is possible given the memory constraints of their cluster.
*/
def flatMapSortedGroups[W: Encoder](func: (K, Iterator[V]) => TraversableOnce[W]): Dataset[W] =
ds.mapPartitions(new GroupedIterator(_).flatMap(v => func(v._1, v._2)))
/**
* (Scala-specific) Applies the given function to each group of data. For each unique group, the function `s` will
* be passed the group key to create a state instance, while the function `func` will be passed that state instance
* and the group values in sequence, according to the sort order within the groups. The function `func` can return
* an iterator
* containing elements of an arbitrary type which will be returned as a new [[Dataset]].
*
* This function does not support partial aggregation, and as a result requires shuffling all the data in the
* [[Dataset]]. If an application intends to perform an aggregation over each key, it is best to use the reduce
* function or an `org.apache.spark.sql.expressions#Aggregator`.
*
* Internally, the implementation will spill to disk if any given group is too large to fit into memory. However,
* users must take care to avoid materializing the whole iterator for a group (for example, by calling `toList`)
* unless they are sure that this is possible given the memory constraints of their cluster.
*/
def flatMapSortedGroups[S, W: Encoder](s: K => S)(func: (S, V) => TraversableOnce[W]): Dataset[W] = {
ds.mapPartitions(new GroupedIterator(_).flatMap { case (k, it) =>
val state = s(k)
it.flatMap(v => func(state, v))
})
}
}
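// Usage sketch (hypothetical data, assuming spark.implicits._ and
// import uk.co.gresearch.spark._ for groupBySorted):
//
//   val ds = Seq(("alice", 3), ("alice", 1), ("bob", 2)).toDS
//   // enumerate each person's values in ascending order
//   ds.groupBySorted[String]($"_1")($"_2")
//     .flatMapSortedGroups((key, it) => it.zipWithIndex)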
private[spark] object SortedGroupByDataset {
def apply[K: Ordering: Encoder, V](
ds: Dataset[V],
groupColumns: Seq[Column],
orderColumns: Seq[Column],
partitions: Option[Int]
): SortedGroupByDataset[K, V] = {
// make ds encoder implicitly available
implicit val valueEncoder: Encoder[V] = ds.encoder
// multiple group columns are turned into a tuple,
// while a single group column is taken as is
val keyColumn =
if (groupColumns.length == 1)
groupColumns.head
else
struct(groupColumns: _*)
// all columns are turned into a single column as a struct
val valColumn = struct(col("*"))
// repartition by group columns with given number of partitions (if given)
// sort within partitions by group and order columns
// finally, turn key and value into typed classes
val grouped = ds
.on(partitions.isDefined)
.either(_.repartition(partitions.get, groupColumns: _*))
.or(_.repartition(groupColumns: _*))
.sortWithinPartitions(groupColumns ++ orderColumns: _*)
.select(
keyColumn.as("key").as[K],
valColumn.as("value").as[V]
)
SortedGroupByDataset(grouped)
}
def apply[K: Ordering: Encoder, V, O: Encoder](
ds: Dataset[V],
key: V => K,
order: V => O,
partitions: Option[Int],
reverse: Boolean
): SortedGroupByDataset[K, V] = {
// prepare encoder needed for this exercise
val keyEncoder: Encoder[K] = implicitly[Encoder[K]]
implicit val valueEncoder: Encoder[V] = ds.encoder
val orderEncoder: Encoder[O] = implicitly[Encoder[O]]
implicit val kvEncoder: Encoder[(K, V)] = Encoders.tuple(keyEncoder, valueEncoder)
implicit val kvoEncoder: Encoder[(K, V, O)] = Encoders.tuple(keyEncoder, valueEncoder, orderEncoder)
// materialise the key and order class for each value
val kvo = ds.map(v => (key(v), v, order(v)))
// sort by key and order column
def keyColumn = col(kvo.columns.head)
def orderColumn = if (reverse) col(kvo.columns.last).desc else col(kvo.columns.last)
// repartition by group columns with given number of partitions (if given)
// sort within partitions by group and order columns
// finally, turn key and value into typed classes
val grouped = kvo
.on(partitions.isDefined)
.either(_.repartition(partitions.get, keyColumn))
.or(_.repartition(keyColumn))
.sortWithinPartitions(keyColumn, orderColumn)
.map(v => (v._1, v._2))
SortedGroupByDataset(grouped)
}
}
private[group] class GroupedIterator[K: Ordering, V](iter: Iterator[(K, V)]) extends Iterator[(K, Iterator[V])] {
private val values = iter.buffered
private var currentKey: Option[K] = None
private var currentGroup: Option[Iterator[V]] = None
override def hasNext: Boolean = {
if (currentKey.isEmpty) {
if (currentGroup.isDefined) {
// consume current group
val it = currentGroup.get
while (it.hasNext) it.next
currentGroup = None
}
if (values.hasNext) {
currentKey = Some(values.head._1)
currentGroup = Some(new GroupIterator(values))
}
}
currentKey.isDefined
}
override def next(): (K, Iterator[V]) = {
try {
(currentKey.get, currentGroup.get)
} finally {
currentKey = None
}
}
}
private[group] class GroupIterator[K: Ordering, V](iter: BufferedIterator[(K, V)]) extends Iterator[V] {
private val ordering = implicitly[Ordering[K]]
private val key = iter.head._1
private def identicalKeys(one: K, two: K): Boolean =
one == null && two == null || one != null && two != null && ordering.equiv(one, two)
override def hasNext: Boolean = iter.hasNext && identicalKeys(iter.head._1, key)
override def next(): V = iter.next._2
}
}
================================================
FILE: src/main/scala/uk/co/gresearch/spark/package.scala
================================================
/*
* Copyright 2020 G-Research
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package uk.co.gresearch
import org.apache.spark.internal.Logging
import org.apache.spark.sql._
import org.apache.spark.sql.ColumnName
import org.apache.spark.sql.catalyst.expressions.{NamedExpression, UnixMicros}
import org.apache.spark.sql.extension.{ColumnExtension, ExpressionExtension}
import org.apache.spark.sql.functions.{col, count, lit, when}
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types.{DecimalType, LongType, TimestampType}
import org.apache.spark.storage.StorageLevel
import org.apache.spark.{SparkContext, SparkFiles}
import uk.co.gresearch.spark.group.SortedGroupByDataset
import java.nio.file.{Files, Paths}
package object spark extends Logging with SparkVersion with BuildVersion {
/**
* Provides a prefix that makes any string distinct w.r.t. the given strings.
* @param existing
* strings
* @return
* distinct prefix
*/
private[spark] def distinctPrefixFor(existing: Seq[String]): String = {
// count number of suffix _ for each existing column name
// return string with one more _ than that
"_" * (existing.map(_.takeWhile(_ == '_').length).reduceOption(_ max _).getOrElse(0) + 1)
}
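// For illustration: distinctPrefixFor(Seq("a", "_b", "__c")) returns "___",
// which is one underscore longer than any existing prefix of underscores and
// therefore cannot produce a name clash when prepended.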
/**
* Create a temporary directory in a location (driver temp dir) that will be deleted on Spark application shutdown.
* @param prefix
* prefix string of temporary directory name
* @return
* absolute path of temporary directory
*/
def createTemporaryDir(prefix: String): String = {
// SparkFiles.getRootDirectory() will be deleted on spark application shutdown
Files.createTempDirectory(Paths.get(SparkFiles.getRootDirectory()), prefix).toAbsolutePath.toString
}
// https://issues.apache.org/jira/browse/SPARK-40588
private[spark] def writePartitionedByRequiresCaching[T](ds: Dataset[T]): Boolean = {
ds.sparkSession.conf
.get(
SQLConf.ADAPTIVE_EXECUTION_ENABLED.key,
SQLConf.ADAPTIVE_EXECUTION_ENABLED.defaultValue.getOrElse(true).toString
)
.equalsIgnoreCase("true") && Some(ds.sparkSession.version).exists(ver =>
Set("3.0.", "3.1.", "3.2.0", "3.2.1", "3.2.2", "3.3.0", "3.3.1").exists(pat =>
if (pat.endsWith(".")) { ver.startsWith(pat) }
else { ver.equals(pat) || ver.startsWith(pat + "-") }
)
)
}
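// For illustration: this returns true e.g. for Spark 3.2.2 or any 3.1.x with
// AQE enabled, and false for Spark 3.2.3, 3.3.2 and later, or whenever
// spark.sql.adaptive.enabled is false.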
private[spark] def info(msg: String): Unit = logInfo(msg)
private[spark] def warning(msg: String): Unit = logWarning(msg)
/**
* Encloses the given strings with backticks (backquotes) if needed.
*
* Backticks are not needed for strings that start with a letter (`a`-`z` and `A`-`Z`) or an underscore, and contain
* only letters, numbers and underscores.
*
* Multiple strings will be enclosed individually and concatenated with dots (`.`).
*
* This is useful when referencing column names that contain special characters like dots (`.`) or backquotes.
*
* Examples:
* {{{
* col("a.column") // this references the field "column" of column "a"
* col("`a.column`") // this reference the column with the name "a.column"
* col(backticks("column")) // produces "column"
* col(backticks("a.column")) // produces "`a.column`"
* col(backticks("a column")) // produces "`a column`"
* col(backticks("`a.column`")) // produces "`a.column`"
* col(backticks("a.column", "a.field")) // produces "`a.column`.`a.field`"
* }}}
*
* @param string
* a string
* @param strings
* more strings
*/
@scala.annotation.varargs
def backticks(string: String, strings: String*): String =
Backticks.column_name(string, strings: _*)
/**
* Aggregate function: returns the number of items in a group that are null.
*/
def count_null(e: Column): Column = count(when(e.isNull, lit(1)))
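// Usage sketch (hypothetical columns "id" and "value"):
//   df.groupBy($"id").agg(count_null($"value"))  // nulls in "value" per id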
private val nanoSecondsPerDotNetTick: Long = 100
private val dotNetTicksPerSecond: Long = 10000000
private val unixEpochDotNetTicks: Long = 621355968000000000L
/**
* Convert a .Net `DateTime.Ticks` timestamp to a Spark timestamp. The input column must be convertible to a number
* (e.g. string, int, long). The Spark timestamp type does not support nanoseconds, so the last digit of the
* timestamp (1/10 of a microsecond) is lost.
*
* Example:
* {{{
* df.select($"ticks", dotNetTicksToTimestamp($"ticks").as("timestamp")).show(false)
* }}}
*
* | ticks | timestamp |
* |:-------------------|:---------------------------|
* | 638155413748959318 | 2023-03-27 21:16:14.895931 |
*
* Note: the example timestamp lacks the final 8/10 of a microsecond. Use `dotNetTicksToUnixEpoch` to preserve the full
* precision of the tick timestamp.
*
* https://learn.microsoft.com/de-de/dotnet/api/system.datetime.ticks
*
* @param tickColumn
* column with a tick value
* @return
* result timestamp column
*/
def dotNetTicksToTimestamp(tickColumn: Column): Column =
dotNetTicksToUnixEpoch(tickColumn).cast(TimestampType)
/**
* Convert a .Net `DateTime.Ticks` timestamp to a Spark timestamp. The input column must be convertible to a number
* (e.g. string, int, long). The Spark timestamp type does not support nanoseconds, so the last digit of the
* timestamp (1/10 of a microsecond) is lost.
*
* {{{
* df.select($"ticks", dotNetTicksToTimestamp("ticks").as("timestamp")).show(false)
* }}}
*
* | ticks | timestamp |
* |:-------------------|:---------------------------|
* | 638155413748959318 | 2023-03-27 21:16:14.895931 |
*
* Note: the example timestamp lacks the final 8/10 of a microsecond. Use `dotNetTicksToUnixEpoch` to preserve the full
* precision of the tick timestamp.
*
* https://learn.microsoft.com/de-de/dotnet/api/system.datetime.ticks
*
* @param tickColumnName
* name of a column with a tick value
* @return
* result timestamp column
*/
def dotNetTicksToTimestamp(tickColumnName: String): Column = dotNetTicksToTimestamp(col(tickColumnName))
/**
* Convert a .Net `DateTime.Ticks` timestamp to Unix epoch seconds as a decimal. The input column must be convertible to a number
* (e.g. string, int, long). The full precision of the tick timestamp is preserved (1/10 of a microsecond).
*
* Example:
* {{{
* df.select($"ticks", dotNetTicksToUnixEpoch($"ticks").as("timestamp")).show(false)
* }}}
*
* | ticks | timestamp |
* |:-------------------|:---------------------|
* | 638155413748959318 | 1679944574.895931800 |
*
* https://learn.microsoft.com/de-de/dotnet/api/system.datetime.ticks
*
* @param tickColumn
* column with a tick value
* @return
* result unix epoch seconds column as decimal
*/
def dotNetTicksToUnixEpoch(tickColumn: Column): Column =
(tickColumn.cast(DecimalType(19, 0)) - unixEpochDotNetTicks) / dotNetTicksPerSecond
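// Worked example of the formula above, using the tick value from the scaladoc:
//   (638155413748959318 - 621355968000000000) / 10000000
//     = 16799445748959318 / 10000000
//     = 1679944574.8959318 seconds since the Unix epoch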
/**
* Convert a .Net `DateTime.Ticks` timestamp to Unix epoch seconds. The input column must be convertible to a number
* (e.g. string, int, long). The full precision of the tick timestamp is preserved (1/10 of a microsecond).
*
* Example:
* {{{
* df.select($"ticks", dotNetTicksToUnixEpoch("ticks").as("timestamp")).show(false)
* }}}
*
* | ticks | timestamp |
* |:-------------------|:---------------------|
* | 638155413748959318 | 1679944574.895931800 |
*
* https://learn.microsoft.com/de-de/dotnet/api/system.datetime.ticks
*
* @param tickColumnName
* name of column with a tick value
* @return
* result unix epoch seconds column as decimal
*/
def dotNetTicksToUnixEpoch(tickColumnName: String): Column = dotNetTicksToUnixEpoch(col(tickColumnName))
/**
* Convert a .Net `DateTime.Ticks` timestamp to Unix epoch nanoseconds. The input column must be convertible to a number
* (e.g. string, int, long). The full precision of the tick timestamp is preserved (1/10 of a microsecond).
*
* Example:
* {{{
* df.select($"ticks", dotNetTicksToUnixEpochNanos($"ticks").as("timestamp")).show(false)
* }}}
*
* | ticks | timestamp |
* |:-------------------|:--------------------|
* | 638155413748959318 | 1679944574895931800 |
*
* https://learn.microsoft.com/de-de/dotnet/api/system.datetime.ticks
*
* @param tickColumn
* column with a tick value
* @return
* result unix epoch nanoseconds column as long
*/
def dotNetTicksToUnixEpochNanos(tickColumn: Column): Column = {
when(
tickColumn <= 713589688368547758L,
(tickColumn.cast(LongType) - unixEpochDotNetTicks) * nanoSecondsPerDotNetTick
)
}
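// The guard above prevents Long overflow: 713589688368547758 is the largest
// tick value whose nanosecond representation still fits into a signed 64-bit
// long ((713589688368547758 - 621355968000000000) * 100 == Long.MaxValue - 7);
// larger tick values yield null.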
/**
* Convert a .Net `DateTime.Ticks` timestamp to Unix epoch nanoseconds. The input column must be convertible to a
* number (e.g. string, int, long). The full precision of the tick timestamp is preserved (1/10 of a microsecond).
*
* Example:
* {{{
* df.select($"ticks", dotNetTicksToUnixEpochNanos("ticks").as("timestamp")).show(false)
* }}}
*
* | ticks | timestamp |
* |:-------------------|:--------------------|
* | 638155413748959318 | 1679944574895931800 |
*
* https://learn.microsoft.com/de-de/dotnet/api/system.datetime.ticks
*
* @param tickColumnName
* name of column with a tick value
* @return
* result unix epoch nanoseconds column as long
*/
def dotNetTicksToUnixEpochNanos(tickColumnName: String): Column = dotNetTicksToUnixEpochNanos(col(tickColumnName))
/**
* Convert a Spark timestamp to a .Net `DateTime.Ticks` timestamp. The input column must be of TimestampType.
*
* Example:
* {{{
* df.select($"timestamp", timestampToDotNetTicks($"timestamp").as("ticks")).show(false)
* }}}
*
* | timestamp | ticks |
* |:---------------------------|:-------------------|
* | 2023-03-27 21:16:14.895931 | 638155413748959310 |
*
* https://learn.microsoft.com/de-de/dotnet/api/system.datetime.ticks
*
* @param timestampColumn
* column with a timestamp value
* @return
* result tick value column
*/
def timestampToDotNetTicks(timestampColumn: Column): Column =
unixEpochTenthMicrosToDotNetTicks(UnixMicros(timestampColumn.expr).column * 10)
/**
* Convert a Spark timestamp to a .Net `DateTime.Ticks` timestamp. The input column must be of TimestampType.
*
* Example:
* {{{
* df.select($"timestamp", timestampToDotNetTicks("timestamp").as("ticks")).show(false)
* }}}
*
* | timestamp | ticks |
* |:---------------------------|:-------------------|
* | 2023-03-27 21:16:14.895931 | 638155413748959310 |
*
* https://learn.microsoft.com/de-de/dotnet/api/system.datetime.ticks
*
* @param timestampColumnName
* name of column with a timestamp value
* @return
* result tick value column
*/
def timestampToDotNetTicks(timestampColumnName: String): Column = timestampToDotNetTicks(col(timestampColumnName))
/**
* Convert a Unix epoch timestamp to a .Net `DateTime.Ticks` timestamp. The input column must represent a numerical
* unix epoch timestamp, e.g. long, double, string or decimal. The input must not be of TimestampType, as that may be
* interpreted incorrectly. Use `timestampToDotNetTicks` for TimestampType columns instead.
*
* Example:
* {{{
* df.select($"unix", unixEpochToDotNetTicks($"unix").as("ticks")).show(false)
* }}}
*
* | unix | ticks |
* |:------------------------------|:-------------------|
* | 1679944574.895931234000000000 | 638155413748959312 |
*
* https://learn.microsoft.com/de-de/dotnet/api/system.datetime.ticks
*
* @param unixTimeColumn
* column with a unix epoch timestamp value
* @return
* result tick value column
*/
def unixEpochToDotNetTicks(unixTimeColumn: Column): Column = unixEpochTenthMicrosToDotNetTicks(
unixTimeColumn.cast(DecimalType(19, 7)) * 10000000
)
/**
* Convert a Unix epoch timestamp to a .Net `DateTime.Ticks` timestamp. The input column must represent a numerical
* unix epoch timestamp, e.g. long, double, string or decimal. The input must not be of TimestampType, as that may be
* interpreted incorrectly. Use `timestampToDotNetTicks` for TimestampType columns instead.
*
* Example:
* {{{
* df.select($"unix", unixEpochToDotNetTicks("unix").as("ticks")).show(false)
* }}}
*
* | unix | ticks |
* |:------------------------------|:-------------------|
* | 1679944574.895931234000000000 | 638155413748959312 |
*
* https://learn.microsoft.com/de-de/dotnet/api/system.datetime.ticks
*
* @param unixTimeColumnName
* name of column with a unix epoch timestamp value
* @return
* result tick value column
*/
def unixEpochToDotNetTicks(unixTimeColumnName: String): Column = unixEpochToDotNetTicks(col(unixTimeColumnName))
/**
* Convert a Unix epoch nanosecond timestamp to a .Net `DateTime.Ticks` timestamp. The .Net ticks timestamp does not
* support the two lowest nanosecond digits, so only a 1/10 of a microsecond is the smallest resolution. The input
* column must represent a numerical unix epoch nanoseconds timestamp, e.g. long, double, string or decimal.
*
* Example:
* {{{
* df.select($"unix_nanos", unixEpochNanosToDotNetTicks($"unix_nanos").as("ticks")).show(false)
* }}}
*
* | unix_nanos | ticks |
* |:--------------------|:-------------------|
* | 1679944574895931234 | 638155413748959312 |
*
* Note: the example timestamp lacks the two lower nanosecond digits as this precision is not supported by .Net ticks.
*
* https://learn.microsoft.com/de-de/dotnet/api/system.datetime.ticks
*
* @param unixNanosColumn
* column with a unix epoch timestamp value
* @return
* result tick value column
*/
def unixEpochNanosToDotNetTicks(unixNanosColumn: Column): Column = unixEpochTenthMicrosToDotNetTicks(
unixNanosColumn.cast(DecimalType(21, 0)) / nanoSecondsPerDotNetTick
)
/**
* Convert a Unix epoch nanosecond timestamp to a .Net `DateTime.Ticks` timestamp. The .Net ticks timestamp does not
* support the two lowest nanosecond digits, so only a 1/10 of a microsecond is the smallest resolution. The input
* column must represent a numerical unix epoch nanoseconds timestamp, e.g. long, double, string or decimal.
*
* Example:
* {{{
* df.select($"unix_nanos", unixEpochNanosToDotNetTicks($"unix_nanos").as("ticks")).show(false)
* }}}
*
* | unix_nanos | ticks |
* |:--------------------|:-------------------|
* | 1679944574895931234 | 638155413748959312 |
*
* Note: the example timestamp lacks the two lower nanosecond digits as this precision is not supported by .Net ticks.
*
* https://learn.microsoft.com/de-de/dotnet/api/system.datetime.ticks
*
* @param unixNanosColumnName
* name of column with a unix epoch timestamp value
* @return
* result tick value column
*/
def unixEpochNanosToDotNetTicks(unixNanosColumnName: String): Column = unixEpochNanosToDotNetTicks(
col(unixNanosColumnName)
)
private def unixEpochTenthMicrosToDotNetTicks(unixNanosColumn: Column): Column =
unixNanosColumn.cast(LongType) + unixEpochDotNetTicks
/**
* Set the job description and return the earlier description. With `ifNotSet == true`, the description is only set if no description is set yet.
*
* @param description
* job description
* @param ifNotSet
* job description is only set if no description is set yet
* @param context
* spark context
* @return
* the earlier job description, or null if none was set
*/
def setJobDescription(description: String, ifNotSet: Boolean = false)(implicit context: SparkContext): String = {
val earlierDescriptionOption = Option(context.getLocalProperty("spark.job.description"))
if (earlierDescriptionOption.isEmpty || !ifNotSet) {
context.setJobDescription(description)
}
earlierDescriptionOption.orNull
}
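// Usage sketch (assuming an implicit SparkContext in scope):
//   implicit val sc: SparkContext = spark.sparkContext
//   val earlier = setJobDescription("loading data")
//   try { /* run Spark jobs */ } finally { setJobDescription(earlier) }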
/**
* Adds a job description to all Spark jobs started within the given function. The current Job description is restored
* after exit of the function.
*
* Usage example:
*
* {{{
* import uk.co.gresearch.spark._
*
* implicit val session: SparkSession = spark
*
* val count = withJobDescription("parquet file") {
* val df = spark.read.parquet("data.parquet")
* df.count
* }
* }}}
*
* With `ifNotSet == true`, the description is only set if no job description is set yet.
*
* Any modification to the job description during execution of the function is reverted, even if `ifNotSet == true`.
*
* @param description
* job description
* @param ifNotSet
* job description is only set if no description is set yet
* @param func
* code to execute while job description is set
* @param session
* spark session
* @tparam T
* return type of func
*/
def withJobDescription[T](description: String, ifNotSet: Boolean = false)(
func: => T
)(implicit session: SparkSession): T = {
val earlierDescription = setJobDescription(description, ifNotSet)(session.sparkContext)
try {
func
} finally {
setJobDescription(earlierDescription)(session.sparkContext)
}
}
/**
* Append the job description and return the earlier description.
*
* @param extraDescription
* job description
* @param separator
* separator to join existing and extra description with
* @param context
* spark context
* @return
* the earlier job description, or null if none was set
*/
def appendJobDescription(extraDescription: String, separator: String, context: SparkContext): String = {
val earlierDescriptionOption = Option(context.getLocalProperty("spark.job.description"))
val description = earlierDescriptionOption.map(_ + separator + extraDescription).getOrElse(extraDescription)
context.setJobDescription(description)
earlierDescriptionOption.orNull
}
/**
* Appends a job description to all Spark jobs started within the given function. The current Job description is
* extended by the separator and the extra description on entering the function, and restored after exit of the
* function.
*
* Usage example:
*
* {{{
* import uk.co.gresearch.spark._
*
* implicit val session: SparkSession = spark
*
* val count = appendJobDescription("parquet file") {
* val df = spark.read.parquet("data.parquet")
* appendJobDescription("count") {
* df.count
* }
* }
* }}}
*
* Any modification to the job description during execution of the function is reverted.
*
* @param extraDescription
* job description to be appended
* @param separator
* separator used when appending description
* @param func
* code to execute while job description is set
* @param session
* spark session
* @tparam T
* return type of func
*/
def appendJobDescription[T](extraDescription: String, separator: String = " - ")(
func: => T
)(implicit session: SparkSession): T = {
val earlierDescription = appendJobDescription(extraDescription, separator, session.sparkContext)
try {
func
} finally {
setJobDescription(earlierDescription)(session.sparkContext)
}
}
/**
* Class to extend a Spark Dataset.
*
* @param ds
* dataset
* @tparam V
* inner type of dataset
*/
@deprecated(
"Constructor with encoder is deprecated, the encoder argument is ignored, ds.encoder is used instead.",
since = "2.9.0"
)
class ExtendedDataset[V](ds: Dataset[V], encoder: Encoder[V]) {
private val eds = ExtendedDatasetV2[V](ds)
def histogram[T: Ordering](thresholds: Seq[T], valueColumn: Column, aggregateColumns: Column*): DataFrame =
eds.histogram(thresholds, valueColumn, aggregateColumns: _*)
def writePartitionedBy(
partitionColumns: Seq[Column],
moreFileColumns: Seq[Column] = Seq.empty,
moreFileOrder: Seq[Column] = Seq.empty,
partitions: Option[Int] = None,
writtenProjection: Option[Seq[Column]] = None,
unpersistHandle: Option[UnpersistHandle] = None
): DataFrameWriter[Row] =
eds.writePartitionedBy(
partitionColumns,
moreFileColumns,
moreFileOrder,
partitions,
writtenProjection,
unpersistHandle
)
def groupBySorted[K: Ordering: Encoder](cols: Column*)(order: Column*): SortedGroupByDataset[K, V] =
eds.groupBySorted(cols: _*)(order: _*)
def groupBySorted[K: Ordering: Encoder](partitions: Int)(cols: Column*)(
order: Column*
): SortedGroupByDataset[K, V] =
eds.groupBySorted(partitions)(cols: _*)(order: _*)
def groupByKeySorted[K: Ordering: Encoder, O: Encoder](key: V => K, partitions: Int)(
order: V => O
): SortedGroupByDataset[K, V] =
eds.groupByKeySorted(key, Some(partitions))(order)
def groupByKeySorted[K: Ordering: Encoder, O: Encoder](key: V => K, partitions: Int)(
order: V => O,
reverse: Boolean
): SortedGroupByDataset[K, V] =
eds.groupByKeySorted(key, Some(partitions))(order, reverse)
def groupByKeySorted[K: Ordering: Encoder, O: Encoder](key: V => K, partitions: Option[Int] = None)(
order: V => O,
reverse: Boolean = false
): SortedGroupByDataset[K, V] =
eds.groupByKeySorted(key, partitions)(order, reverse)
def withRowNumbers(order: Column*): DataFrame =
eds.withRowNumbers(order: _*)
def withRowNumbers(rowNumberColumnName: String, order: Column*): DataFrame =
eds.withRowNumbers(rowNumberColumnName, order: _*)
def withRowNumbers(storageLevel: StorageLevel, order: Column*): DataFrame =
eds.withRowNumbers(storageLevel, order: _*)
def withRowNumbers(unpersistHandle: UnpersistHandle, order: Column*): DataFrame =
eds.withRowNumbers(unpersistHandle, order: _*)
def withRowNumbers(rowNumberColumnName: String, storageLevel: StorageLevel, order: Column*): DataFrame =
eds.withRowNumbers(rowNumberColumnName, storageLevel, order: _*)
def withRowNumbers(rowNumberColumnName: String, unpersistHandle: UnpersistHandle, order: Column*): DataFrame =
eds.withRowNumbers(rowNumberColumnName, unpersistHandle, order: _*)
def withRowNumbers(storageLevel: StorageLevel, unpersistHandle: UnpersistHandle, order: Column*): DataFrame =
eds.withRowNumbers(storageLevel, unpersistHandle, order: _*)
def withRowNumbers(
rowNumberColumnName: String,
storageLevel: StorageLevel,
unpersistHandle: UnpersistHandle,
order: Column*
): DataFrame =
eds.withRowNumbers(rowNumberColumnName, storageLevel, unpersistHandle, order: _*)
}
/**
* Factory for the deprecated [[ExtendedDataset]] class that extends a Spark Dataset.
*
* @param ds
* dataset
* @tparam V
* inner type of dataset
*/
def ExtendedDataset[V](ds: Dataset[V], encoder: Encoder[V]): ExtendedDataset[V] = new ExtendedDataset(ds, encoder)
/**
* Implicit class to extend a Spark Dataset.
*
* @param ds
* dataset
* @tparam V
* inner type of dataset
*/
implicit class ExtendedDatasetV2[V](ds: Dataset[V]) {
private implicit val encoder: Encoder[V] = ds.encoder
/**
* Compute the histogram of a column when aggregated by aggregate columns. Thresholds are expected to be provided in
* ascending order. The result dataframe contains the aggregate and histogram columns only. For each threshold value
* in thresholds, there will be a column named s"≤threshold". There will also be a final column called
* s">last_threshold", that counts the remaining values that exceed the last threshold.
*
* @param thresholds
* sequence of thresholds, must implement <= and > operators w.r.t. valueColumn
* @param valueColumn
* histogram is computed for values of this column
* @param aggregateColumns
* histogram is computed against these columns
* @tparam T
* type of histogram thresholds
* @return
* dataframe with aggregate and histogram columns
*/
def histogram[T: Ordering](thresholds: Seq[T], valueColumn: Column, aggregateColumns: Column*): DataFrame =
Histogram.of(ds, thresholds, valueColumn, aggregateColumns: _*)
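// Usage sketch (hypothetical columns "age" and "title"): buckets ages per title
// into columns "≤18", "≤35", "≤65" and ">65".
//   ds.histogram(Seq(18, 35, 65), $"age", $"title")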
/**
* Writes the Dataset / DataFrame via DataFrameWriter.partitionBy. In addition to partitionBy, this method sorts the
* data to improve partition file size. Small partitions will contain few files, large partitions contain more
* files. Each partition id is contained in a single partition file per `partitionBy` partition only. Rows within
* the partition files are also sorted, if `moreFileOrder` is defined.
*
* Note: With Spark 3.0, 3.1, 3.2 before 3.2.3, 3.3 before 3.3.2, and AQE enabled, an intermediate DataFrame is
* being cached in order to guarantee sorted output files. See https://issues.apache.org/jira/browse/SPARK-40588.
* That cached DataFrame can be unpersisted via an optional [[UnpersistHandle]] provided to this method.
*
* Calling:
* {{{
* val unpersist = UnpersistHandle()
* val writer = df.writePartitionedBy(Seq("a"), Seq("b"), Seq("c"), Some(10), Seq($"a", concat($"b", $"c")), unpersist)
* writer.parquet("data.parquet")
* unpersist()
* }}}
*
* is equivalent to:
* {{{
* val cached =
* df.repartitionByRange(10, $"a", $"b")
* .sortWithinPartitions($"a", $"b", $"c")
* .cache
*
* val writer =
* cached
* .select($"a", concat($"b", $"c"))
* .write
* .partitionBy("a")
*
* writer.parquet("data.parquet")
*
* cached.unpersist
* }}}
*
* @param partitionColumns
* columns used for partitioning
* @param moreFileColumns
* columns where individual values are written to a single file
* @param moreFileOrder
* additional columns to sort partition files
* @param partitions
* optional number of partition files
* @param writtenProjection
* additional transformation to be applied before calling write
* @param unpersistHandle
* handle to unpersist internally created DataFrame after writing
* @return
* configured DataFrameWriter
*/
def writePartitionedBy(
partitionColumns: Seq[Column],
moreFileColumns: Seq[Column] = Seq.empty,
moreFileOrder: Seq[Column] = Seq.empty,
partitions: Option[Int] = None,
writtenProjection: Option[Seq[Column]] = None,
unpersistHandle: Option[UnpersistHandle] = None
): DataFrameWriter[Row] = {
if (partitionColumns.isEmpty)
throw new IllegalArgumentException(s"partition columns must not be empty")
if (partitionColumns.exists(col => !col.isInstanceOf[ColumnName] && !col.expr.isInstanceOf[NamedExpression]))
throw new IllegalArgumentException(s"partition columns must be named: ${partitionColumns.mkString(",")}")
val requiresCaching = writePartitionedByRequiresCaching(ds)
(requiresCaching, unpersistHandle.isDefined) match {
case (true, false) =>
warning(
"Partitioned-writing with AQE enabled and Spark 3.0, 3.1, 3.2 below 3.2.3, " +
"and 3.3 below 3.3.2 requires caching an intermediate DataFrame, " +
"which calling code has to unpersist once writing is done. " +
"Please provide an UnpersistHandle to DataFrame.writePartitionedBy, or UnpersistHandle.Noop. " +
"See https://issues.apache.org/jira/browse/SPARK-40588"
)
case (false, true) if !unpersistHandle.get.isInstanceOf[NoopUnpersistHandle] =>
info(
"UnpersistHandle provided to DataFrame.writePartitionedBy is not needed as " +
"partitioned-writing with AQE disabled or Spark 3.2.3, 3.3.2 or 3.4 and above " +
"does not require caching intermediate DataFrame."
)
unpersistHandle.get.setDataFrame(ds.sparkSession.emptyDataFrame)
case _ =>
}
// resolve partition column names
val partitionColumnNames = ds.select(partitionColumns: _*).queryExecution.analyzed.output.map(_.name)
val partitionColumnsMap = partitionColumnNames.zip(partitionColumns).toMap
val rangeColumns = partitionColumnNames.map(col) ++ moreFileColumns
val sortColumns = partitionColumnNames.map(col) ++ moreFileColumns ++ moreFileOrder
ds.toDF
.call(ds => partitionColumnsMap.foldLeft(ds) { case (ds, (name, col)) => ds.withColumn(name, col) })
.when(partitions.isEmpty)
.call(_.repartitionByRange(rangeColumns: _*))
.when(partitions.isDefined)
.call(_.repartitionByRange(partitions.get, rangeColumns: _*))
.sortWithinPartitions(sortColumns: _*)
.when(writtenProjection.isDefined)
.call(_.select(writtenProjection.get: _*))
.when(requiresCaching && unpersistHandle.isDefined)
.call(unpersistHandle.get.setDataFrame(_))
.write
.partitionBy(partitionColumnsMap.keys.toSeq: _*)
}
/**
* (Scala-specific) Returns a [[KeyValueGroupedDataset]] where the data is grouped by the given key columns.
*
* @see
* `org.apache.spark.sql.Dataset.groupByKey(T => K)`
*
* @note
* This method should be preferred over `groupByKey(T => K)` because, with that function, the Catalyst query
* planner cannot exploit the existing partitioning and ordering of this Dataset.
*
* {{{
* ds.groupByKey[Int]($"age").flatMapGroups(...)
* ds.groupByKey[(String, String)]($"department", $"gender").flatMapGroups(...)
* }}}
*/
def groupByKey[K: Encoder](column: Column, columns: Column*): KeyValueGroupedDataset[K, V] =
ds.groupBy(column +: columns: _*).as[K, V]
/**
* (Scala-specific) Returns a [[KeyValueGroupedDataset]] where the data is grouped by the given key columns.
*
* @see
* `org.apache.spark.sql.Dataset.groupByKey(T => K)`
*
* @note
* This method should be preferred over `groupByKey(T => K)` because, with that function, the Catalyst query
* planner cannot exploit the existing partitioning and ordering of this Dataset.
*
* {{{
* ds.groupByKey[Int]($"age").flatMapGroups(...)
* ds.groupByKey[(String, String)]($"department", $"gender").flatMapGroups(...)
* }}}
*/
def groupByKey[K: Encoder](column: String, columns: String*): KeyValueGroupedDataset[K, V] =
ds.groupBy(column, columns: _*).as[K, V]
/**
* Groups the Dataset and sorts the groups using the specified columns, so we can further process the sorted
* groups. See [[uk.co.gresearch.spark.group.SortedGroupByDataset]] for all the available functions.
*
* {{{
* // Enumerate elements in the sorted group
* ds.groupBySorted($"department")($"salery")
* .flatMapSortedGroups((key, it) => it.zipWithIndex)
* }}}
*
* @param cols
* grouping columns
* @param order
* sort columns
*/
def groupBySorted[K: Ordering: Encoder](cols: Column*)(order: Column*): SortedGroupByDataset[K, V] = {
SortedGroupByDataset(ds, cols, order, None)
}
/**
* Groups the Dataset and sorts the groups using the specified columns, so we can further process the sorted
* groups. See [[uk.co.gresearch.spark.group.SortedGroupByDataset]] for all the available functions.
*
* {{{
* // Enumerate elements in the sorted group
* ds.groupBySorted(10)($"department")($"salery")
* .flatMapSortedGroups((key, it) => it.zipWithIndex)
* }}}
*
* @param partitions
* number of partitions
* @param cols
* grouping columns
* @param order
* sort columns
*/
def groupBySorted[K: Ordering: Encoder](
partitions: Int
)(cols: Column*)(order: Column*): SortedGroupByDataset[K, V] = {
SortedGroupByDataset(ds, cols, order, Some(partitions))
}
/**
* Groups the Dataset and sorts the groups using the specified columns, so we can further process the sorted
* groups. See [[uk.co.gresearch.spark.group.SortedGroupByDataset]] for all the available functions.
*
* {{{
* // Enumerate elements in the sorted group
* ds.groupByKeySorted(row => row.getInt(0), 10)(row => row.getInt(1))
* .flatMapSortedGroups((key, it) => it.zipWithIndex)
* }}}
*
* @param partitions
* number of partitions
* @param key
* grouping key
* @param order
* sort key
*/
def groupByKeySorted[K: Ordering: Encoder, O: Encoder](key: V => K, partitions: Int)(
order: V => O
): SortedGroupByDataset[K, V] =
groupByKeySorted(key, Some(partitions))(order)
/**
* Groups the Dataset and sorts the groups using the specified key functions, so we can further process the sorted
* groups. See [[uk.co.gresearch.spark.group.SortedGroupByDataset]] for all the available functions.
*
* {{{
* // Enumerate elements in the sorted group
* ds.groupByKeySorted(row => row.getInt(0), 10)(row => row.getInt(1), true)
* .flatMapSortedGroups((key, it) => it.zipWithIndex)
* }}}
*
* @param partitions
* number of partitions
* @param key
* grouping key
* @param order
* sort key
* @param reverse
* sort reverse order
*/
def groupByKeySorted[K: Ordering: Encoder, O: Encoder](key: V => K, partitions: Int)(
order: V => O,
reverse: Boolean
): SortedGroupByDataset[K, V] =
groupByKeySorted(key, Some(partitions))(order, reverse)
/**
* Groups the Dataset and sorts the groups using the specified key functions, so we can further process the sorted
* groups. See [[uk.co.gresearch.spark.group.SortedGroupByDataset]] for all the available functions.
*
* {{{
* // Enumerate elements in the sorted group
* ds.groupByKeySorted(row => row.getInt(0))(row => row.getInt(1), true)
* .flatMapSortedGroups((key, it) => it.zipWithIndex)
* }}}
*
* @param partitions
* optional number of partitions
* @param key
* grouping key
* @param order
* sort key
* @param reverse
* sort reverse order
*/
def groupByKeySorted[K: Ordering: Encoder, O: Encoder](
key: V => K,
partitions: Option[Int] = None
)(order: V => O, reverse: Boolean = false): SortedGroupByDataset[K, V] = {
SortedGroupByDataset(ds, key, order, partitions, reverse)
}
/**
* Adds a global continuous row number starting at 1.
*
* See [[withRowNumbers(String,StorageLevel,UnpersistHandle,Column...)]] for details.
*/
def withRowNumbers(order: Column*): DataFrame =
RowNumbers.withOrderColumns(order: _*).of(ds)
/**
* Adds a global continuous row number starting at 1.
*
* See [[withRowNumbers(String,StorageLevel,UnpersistHandle,Column...)]] for details.
*/
def withRowNumbers(rowNumberColumnName: String, order: Column*): DataFrame =
RowNumbers.withRowNumberColumnName(rowNumberColumnName).withOrderColumns(order).of(ds)
/**
* Adds a global continuous row number starting at 1.
*
* See [[withRowNumbers(String,StorageLevel,UnpersistHandle,Column...)]] for details.
*/
def withRowNumbers(storageLevel: StorageLevel, order: Column*): DataFrame =
RowNumbers.withStorageLevel(storageLevel).withOrderColumns(order).of(ds)
/**
* Adds a global continuous row number starting at 1.
*
* See [[withRowNumbers(String,StorageLevel,UnpersistHandle,Column...)]] for details.
*/
def withRowNumbers(unpersistHandle: UnpersistHandle, order: Column*): DataFrame =
RowNumbers.withUnpersistHandle(unpersistHandle).withOrderColumns(order).of(ds)
/**
* Adds a global continuous row number starting at 1.
*
* See [[withRowNumbers(String,StorageLevel,UnpersistHandle,Column...)]] for details.
*/
def withRowNumbers(rowNumberColumnName: String, storageLevel: StorageLevel, order: Column*): DataFrame =
RowNumbers
.withRowNumberColumnName(rowNumberColumnName)
.withStorageLevel(storageLevel)
.withOrderColumns(order)
.of(ds)
/**
* Adds a global continuous row number starting at 1.
*
* See [[withRowNumbers(String,StorageLevel,UnpersistHandle,Column...)]] for details.
*/
def withRowNumbers(rowNumberColumnName: String, unpersistHandle: UnpersistHandle, order: Column*): DataFrame =
RowNumbers
.withRowNumberColumnName(rowNumberColumnName)
.withUnpersistHandle(unpersistHandle)
.withOrderColumns(order)
.of(ds)
/**
* Adds a global continuous row number starting at 1.
*
* See [[withRowNumbers(String,StorageLevel,UnpersistHandle,Column...)]] for details.
*/
def withRowNumbers(storageLevel: StorageLevel, unpersistHandle: UnpersistHandle, order: Column*): DataFrame =
RowNumbers.withStorageLevel(storageLevel).withUnpersistHandle(unpersistHandle).withOrderColumns(order).of(ds)
/**
* Adds a global continuous row number starting at 1, after sorting rows by the given columns. When no columns are
* given, the existing order is used.
*
* Hence, the following examples are equivalent:
* {{{
* ds.withRowNumbers($"a".desc, $"b")
* ds.orderBy($"a".desc, $"b").withRowNumbers()
* }}}
*
* The column name of the column with the row numbers can be set via the `rowNumberColumnName` argument.
*
* To avoid some known issues when optimizing the query plan, this function internally calls
* `Dataset.persist(StorageLevel)` on an intermediate DataFrame. The storage level of that cached DataFrame can be
* set via `storageLevel`, where the default is `StorageLevel.MEMORY_AND_DISK`.
*
* That cached intermediate DataFrame can be un-persisted / un-cached as follows:
* {{{
* import uk.co.gresearch.spark.UnpersistHandle
*
* val unpersist = UnpersistHandle()
* ds.withRowNumbers(unpersist).show()
* unpersist()
* }}}
*
* @param rowNumberColumnName
* name of the row number column
* @param storageLevel
* storage level of the cached intermediate DataFrame
* @param unpersistHandle
* handle to un-persist intermediate DataFrame
* @param order
* columns to order the dataframe by before assigning row numbers
* @return
* dataframe with row numbers
*/
def withRowNumbers(
rowNumberColumnName: String,
storageLevel: StorageLevel,
unpersistHandle: UnpersistHandle,
order: Column*
): DataFrame =
RowNumbers
.withRowNumberColumnName(rowNumberColumnName)
.withStorageLevel(storageLevel)
.withUnpersistHandle(unpersistHandle)
.withOrderColumns(order)
.of(ds)
}
/**
* Class to extend a Spark Dataframe.
*
* @param df
* dataframe
*/
@deprecated("Implicit class ExtendedDataframe is deprecated, please recompile your source code.", since = "2.9.0")
class ExtendedDataframe(df: DataFrame) extends ExtendedDataset[Row](df, df.encoder)
/**
* Class to extend a Spark Dataframe.
*
* @param df
* dataframe
*/
def ExtendedDataframe(df: DataFrame): ExtendedDataframe = new ExtendedDataframe(df)
/**
* Implicit class to extend a Spark Dataframe, which is a Dataset[Row].
*
* @param df
* dataframe
*/
implicit class ExtendedDataframeV2(df: DataFrame) extends ExtendedDatasetV2[Row](df)
}
================================================
FILE: src/main/scala/uk/co/gresearch/spark/parquet/ParquetMetaDataUtil.scala
================================================
/*
* Copyright 2023 G-Research
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package uk.co.gresearch.spark.parquet
import org.apache.parquet.crypto.ParquetCryptoRuntimeException
import org.apache.parquet.hadoop.Footer
import org.apache.parquet.hadoop.metadata.{BlockMetaData, ColumnChunkMetaData, FileMetaData}
import org.apache.parquet.schema.PrimitiveType
import scala.reflect.{ClassTag, classTag}
import scala.util.Try
import scala.collection.convert.ImplicitConversions.`iterable AsScalaIterable`
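/**
 * Reflection helpers for methods that exist only in some versions of the Parquet library: `isSupported` probes for
 * a no-argument method at runtime, while `guard` and `guardOption` lift a function into one that returns `None`
 * when that method is not supported.
 */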
private trait MethodGuard {
def isSupported[T: ClassTag](methodName: String): Boolean = {
Try(classTag[T].runtimeClass.getMethod(methodName)).isSuccess
}
def guard[T, R](supported: Boolean)(f: T => R): T => Option[R] =
guardOption(supported)(t => Some(f(t)))
def guardOption[T, R](supported: Boolean)(f: T => Option[R]): T => Option[R] =
if (supported) { (v: T) =>
f(v)
} else { (_: T) =>
None
}
}
/**
* Guard access to possibly encrypted and inaccessible metadata of a footer.
* - If the footer is encrypted and we have no decryption keys, metadata values are None.
* - If the footer is known not to be encrypted, metadata values are Some.
* - If we don't know whether the footer is encrypted, we probe some metadata that could not be read if it were
* encrypted, in order to determine the encryption state of the footer.
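*
* A sketch of the intended use (mirroring the metadata readers in the parquet package object):
* {{{
* val guard = FooterGuard(footer)
* // Some(bytes) when the footer is readable, None when it is encrypted and we have no keys
* val bytes: Option[Long] = guard { footer.getParquetMetadata.getBlocks.map(_.getTotalByteSize).sum }
* }}}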
*/
private case class FooterGuard(footer: Footer) {
lazy val isSafe: Boolean = {
// having a decryptor tells us this file is expected to be decryptable
Option(footer.getParquetMetadata.getFileMetaData.getFileDecryptor)
// otherwise, when we have an unencrypted file, we are also safe to access the footer metadata
.orElse(
ParquetMetaDataUtil
.getEncryptionType(footer.getParquetMetadata.getFileMetaData)
.filter(_ == "UNENCRYPTED")
)
// turn to Some(true) if safe, None if unknown
.map(_ => true)
// otherwise, we access some metadata that would fail to read if the footer were encrypted
.orElse(
Some(
Try(footer.getParquetMetadata.getBlocks.headOption.map(_.getTotalByteSize))
// get hold of the possible exception
.toEither.swap.toOption
// no exception means safe, ignore exceptions other than ParquetCryptoRuntimeException
.exists(!_.isInstanceOf[ParquetCryptoRuntimeException])
)
)
// now is Some(true) or Some(false)
.get
}
private[parquet] def apply[T](f: => T): Option[T] = {
if (isSafe) { Some(f) }
else { None }
}
}
private[parquet] object ParquetMetaDataUtil extends MethodGuard {
lazy val getEncryptionTypeIsSupported: Boolean =
isSupported[FileMetaData]("getEncryptionType")
lazy val getEncryptionType: FileMetaData => Option[String] =
guard(getEncryptionTypeIsSupported) { fileMetaData: FileMetaData =>
fileMetaData.getEncryptionType.name()
}
lazy val getLogicalTypeAnnotationIsSupported: Boolean =
isSupported[PrimitiveType]("getLogicalTypeAnnotation")
lazy val getLogicalTypeAnnotation: PrimitiveType => Option[String] =
guardOption(getLogicalTypeAnnotationIsSupported) { (primitive: PrimitiveType) =>
Option(primitive.getLogicalTypeAnnotation).map(_.toString)
}
lazy val getOrdinalIsSupported: Boolean =
isSupported[BlockMetaData]("getOrdinal")
lazy val getOrdinal: BlockMetaData => Option[Int] =
guard(getOrdinalIsSupported) { (block: BlockMetaData) =>
block.getOrdinal
}
lazy val isEncryptedIsSupported: Boolean =
isSupported[ColumnChunkMetaData]("isEncrypted")
lazy val isEncrypted: ColumnChunkMetaData => Option[Boolean] =
guard(isEncryptedIsSupported) { (column: ColumnChunkMetaData) =>
column.isEncrypted
}
}
================================================
FILE: src/main/scala/uk/co/gresearch/spark/parquet/package.scala
================================================
/*
* Copyright 2023 G-Research
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package uk.co.gresearch.spark
// hadoop and parquet dependencies provided by Spark
import org.apache.hadoop.conf.Configuration
import org.apache.hadoop.fs.Path
import org.apache.parquet.hadoop.metadata.BlockMetaData
import org.apache.parquet.hadoop.{Footer, ParquetFileReader}
import org.apache.spark.SparkContext
import org.apache.spark.sql._
import org.apache.spark.sql.execution.datasources.FilePartition
import uk.co.gresearch._
import scala.collection.JavaConverters.{collectionAsScalaIterableConverter, mapAsScalaMapConverter}
import scala.collection.convert.ImplicitConversions.`iterable AsScalaIterable`
package object parquet {
private def conf: Configuration = SparkContext.getOrCreate().hadoopConfiguration
/**
* Implicit class to extend a Spark DataFrameReader.
*
* @param reader
* data frame reader
*/
implicit class ExtendedDataFrameReader(reader: DataFrameReader) {
/**
* Read the metadata of Parquet files into a Dataframe.
*
* The returned DataFrame has as many partitions as there are Parquet files, at most
* `spark.sparkContext.defaultParallelism` partitions.
*
* This provides the following per-file information:
* - filename (string): The file name
* - blocks (int): Number of blocks / RowGroups in the Parquet file
* - compressedBytes (long): Number of compressed bytes of all blocks
* - uncompressedBytes (long): Number of uncompressed bytes of all blocks
* - rows (long): Number of rows in the file
* - columns (int): Number of columns in the file
* - values (long): Number of values in the file
* - nulls (long): Number of null values in the file
* - createdBy (string): The createdBy string of the Parquet file, e.g. library used to write the file
* - schema (string): The Parquet schema of the file
* - encryption (string): The encryption type of the file, if provided by the Parquet library
* - keyValues (string-to-string map): Key-value data of the file
*
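* For example, reading per-file metadata (a minimal sketch; `spark` is assumed to be an active SparkSession and
* the path a placeholder):
* {{{
* import uk.co.gresearch.spark.parquet._
* spark.read.parquetMetadata("/path/to/data.parquet").show()
* }}}
*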
* @param paths
* one or more paths to Parquet files or directories
* @return
* dataframe with Parquet metadata
*/
@scala.annotation.varargs
def parquetMetadata(paths: String*): DataFrame = parquetMetadata(None, paths)
/**
* Read the metadata of Parquet files into a Dataframe.
*
* The returned DataFrame has as many partitions as specified via `parallelism`.
*
* This provides the following per-file information:
* - filename (string): The file name
* - blocks (int): Number of blocks / RowGroups in the Parquet file
* - compressedBytes (long): Number of compressed bytes of all blocks
* - uncompressedBytes (long): Number of uncompressed bytes of all blocks
* - rows (long): Number of rows in the file
* - columns (int): Number of columns in the file
* - values (long): Number of values in the file
* - nulls (long): Number of null values in the file
* - createdBy (string): The createdBy string of the Parquet file, e.g. library used to write the file
* - schema (string): The Parquet schema of the file
* - encryption (string): The encryption type of the file, if provided by the Parquet library
* - keyValues (string-to-string map): Key-value data of the file
*
* @param parallelism
* number of partitions of returned DataFrame
* @param paths
* one or more paths to Parquet files or directories
* @return
* dataframe with Parquet metadata
*/
@scala.annotation.varargs
def parquetMetadata(parallelism: Int, paths: String*): DataFrame = parquetMetadata(Some(parallelism), paths)
private def parquetMetadata(parallelism: Option[Int], paths: Seq[String]): DataFrame = {
val files = getFiles(parallelism, paths)
import files.sparkSession.implicits._
files
.flatMap { case (_, file) =>
readFooters(file).map { footer =>
val guard = FooterGuard(footer)
(
footer.getFile.toString,
footer.getParquetMetadata.getBlocks.size(),
guard { footer.getParquetMetadata.getBlocks.asScala.map(_.getCompressedSize).sum },
guard { footer.getParquetMetadata.getBlocks.asScala.map(_.getTotalByteSize).sum },
footer.getParquetMetadata.getBlocks.asScala.map(_.getRowCount).sum,
footer.getParquetMetadata.getFileMetaData.getSchema.getColumns.size(),
guard {
footer.getParquetMetadata.getBlocks.asScala.map(_.getColumns.map(_.getValueCount).sum).sum
},
// when all columns have statistics, count the null values
guard {
Option(
footer.getParquetMetadata.getBlocks.asScala.flatMap(_.getColumns.map(c => Option(c.getStatistics)))
)
.filter(_.forall(_.isDefined))
.map(_.map(_.get.getNumNulls).sum)
},
footer.getParquetMetadata.getFileMetaData.getCreatedBy,
footer.getParquetMetadata.getFileMetaData.getSchema.toString,
ParquetMetaDataUtil.getEncryptionType(footer.getParquetMetadata.getFileMetaData),
footer.getParquetMetadata.getFileMetaData.getKeyValueMetaData.asScala,
)
}
}
.toDF(
"filename",
"blocks",
"compressedBytes",
"uncompressedBytes",
"rows",
"columns",
"values",
"nulls",
"createdBy",
"schema",
"encryption",
"keyValues"
)
}
/**
* Read the schema of Parquet files into a Dataframe.
*
* The returned DataFrame has as many partitions as there are Parquet files, at most
* `spark.sparkContext.defaultParallelism` partitions.
*
* This provides the following per-file information:
* - filename (string): The Parquet file name
* - columnName (string): The column name
* - columnPath (string array): The column path
* - repetition (string): The repetition
* - type (string): The data type
* - length (int): The length of the type
* - originalType (string): The original type
* - logicalType (string): The logical type
* - isPrimitive (boolean): True if type is primitive
* - primitiveType (string): The primitive type
* - primitiveOrder (string): The order of the primitive type
* - maxDefinitionLevel (int): The max definition level
* - maxRepetitionLevel (int): The max repetition level
*
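* For example, inspecting column types (a sketch; `spark` is assumed to be an active SparkSession and the path a
* placeholder):
* {{{
* import uk.co.gresearch.spark.parquet._
* spark.read.parquetSchema("/path/to/data.parquet").select("filename", "columnName", "type", "logicalType").show()
* }}}
*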
* @param paths
* one or more paths to Parquet files or directories
* @return
* dataframe with Parquet metadata
*/
@scala.annotation.varargs
def parquetSchema(paths: String*): DataFrame = parquetSchema(None, paths)
/**
* Read the schema of Parquet files into a Dataframe.
*
* The returned DataFrame has as many partitions as specified via `parallelism`.
*
* This provides the following per-file information:
* - filename (string): The Parquet file name
* - columnName (string): The column name
* - columnPath (string array): The column path
* - repetition (string): The repetition
* - type (string): The data type
* - length (int): The length of the type
* - originalType (string): The original type
* - logicalType (string): The logical type
* - isPrimitive (boolean): True if type is primitive
* - primitiveType (string): The primitive type
* - primitiveOrder (string): The order of the primitive type
* - maxDefinitionLevel (int): The max definition level
* - maxRepetitionLevel (int): The max repetition level
*
* @param parallelism
* number of partitions of returned DataFrame
* @param paths
* one or more paths to Parquet files or directories
* @return
* dataframe with Parquet metadata
*/
@scala.annotation.varargs
def parquetSchema(parallelism: Int, paths: String*): DataFrame = parquetSchema(Some(parallelism), paths)
private def parquetSchema(parallelism: Option[Int], paths: Seq[String]): DataFrame = {
val files = getFiles(parallelism, paths)
import files.sparkSession.implicits._
files
.flatMap { case (_, file) =>
readFooters(file).flatMap { footer =>
footer.getParquetMetadata.getFileMetaData.getSchema.getColumns.map { column =>
(
footer.getFile.toString,
Option(column.getPrimitiveType).map(_.getName),
column.getPath,
Option(column.getPrimitiveType).flatMap(v => Option(v.getRepetition)).map(_.name),
Option(column.getPrimitiveType).flatMap(v => Option(v.getPrimitiveTypeName)).map(_.name),
Option(column.getPrimitiveType).map(_.getTypeLength),
Option(column.getPrimitiveType).flatMap(v => Option(v.getOriginalType)).map(_.name),
Option(column.getPrimitiveType).flatMap(ParquetMetaDataUtil.getLogicalTypeAnnotation),
column.getPrimitiveType.isPrimitive,
Option(column.getPrimitiveType).map(_.getPrimitiveTypeName.name),
Option(column.getPrimitiveType).flatMap(v => Option(v.columnOrder)).map(_.getColumnOrderName.name),
column.getMaxDefinitionLevel,
column.getMaxRepetitionLevel,
)
}
}
}
.toDF(
"filename",
"columnName",
"columnPath",
"repetition",
"type",
"length",
"originalType",
"logicalType",
"isPrimitive",
"primitiveType",
"primitiveOrder",
"maxDefinitionLevel",
"maxRepetitionLevel",
)
}
/**
* Read the metadata of Parquet blocks into a Dataframe.
*
* The returned DataFrame has as many partitions as there are Parquet files, at most
* `spark.sparkContext.defaultParallelism` partitions.
*
* This provides the following per-block information:
* - filename (string): The file name
* - block (int): Block / RowGroup number starting at 1
* - blockStart (long): Start position of the block in the Parquet file
* - compressedBytes (long): Number of compressed bytes in block
* - uncompressedBytes (long): Number of uncompressed bytes in block
* - rows (long): Number of rows in block
* - columns (int): Number of columns in block
* - values (long): Number of values in block
* - nulls (long): Number of null values in block
*
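* For example, listing blocks per file (a sketch; `spark` is assumed to be an active SparkSession and the path a
* placeholder):
* {{{
* import uk.co.gresearch.spark.parquet._
* spark.read.parquetBlocks("/path/to/data.parquet").orderBy("filename", "block").show()
* }}}
*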
* @param paths
* one or more paths to Parquet files or directories
* @return
* dataframe with Parquet block metadata
*/
@scala.annotation.varargs
def parquetBlocks(paths: String*): DataFrame = parquetBlocks(None, paths)
/**
* Read the metadata of Parquet blocks into a Dataframe.
*
* The returned DataFrame has as many partitions as specified via `parallelism`.
*
* This provides the following per-block information:
* - filename (string): The file name
* - block (int): Block / RowGroup number starting at 1 (block ordinal + 1)
* - blockStart (long): Start position of the block in the Parquet file
* - compressedBytes (long): Number of compressed bytes in block
* - uncompressedBytes (long): Number of uncompressed bytes in block
* - rows (long): Number of rows in block
* - columns (int): Number of columns in block
* - values (long): Number of values in block
* - nulls (long): Number of null values in block
*
* @param parallelism
* number of partitions of returned DataFrame
* @param paths
* one or more paths to Parquet files or directories
* @return
* dataframe with Parquet block metadata
*/
@scala.annotation.varargs
def parquetBlocks(parallelism: Int, paths: String*): DataFrame = parquetBlocks(Some(parallelism), paths)
private def parquetBlocks(parallelism: Option[Int], paths: Seq[String]): DataFrame = {
val files = getFiles(parallelism, paths)
import files.sparkSession.implicits._
files
.flatMap { case (_, file) =>
readFooters(file).flatMap { footer =>
val guard = FooterGuard(footer)
footer.getParquetMetadata.getBlocks.asScala.zipWithIndex.map { case (block, idx) =>
(
footer.getFile.toString,
ParquetMetaDataUtil.getOrdinal(block).getOrElse(idx) + 1,
block.getStartingPos,
guard { block.getCompressedSize },
block.getTotalByteSize,
block.getRowCount,
block.getColumns.asScala.size,
guard { block.getColumns.asScala.map(_.getValueCount).sum },
// when all columns have statistics, count the null values
guard {
Option(block.getColumns.asScala.map(c => Option(c.getStatistics)))
.filter(_.forall(_.isDefined))
.map(_.map(_.get.getNumNulls).sum)
},
)
}
}
}
.toDF(
"filename",
"block",
"blockStart",
"compressedBytes",
"uncompressedBytes",
"rows",
"columns",
"values",
"nulls"
)
}
/**
* Read the metadata of Parquet block columns into a Dataframe.
*
* The returned DataFrame has as many partitions as there are Parquet files, at most
* `spark.sparkContext.defaultParallelism` partitions.
*
* This provides the following per-block-column information:
* - filename (string): The file name
* - block (int): Block / RowGroup number starting at 1
* - column (string array): Block / RowGroup column path
* - codec (string): The codec used to compress the block column values
* - type (string): The data type of the block column
* - encodings (string array): Encodings of the block column
* - encrypted (boolean): True if the block column is encrypted
* - minValue (string): Minimum value of this column in this block
* - maxValue (string): Maximum value of this column in this block
* - columnStart (long): Start position of the block column in the Parquet file
* - compressedBytes (long): Number of compressed bytes of this block column
* - uncompressedBytes (long): Number of uncompressed bytes of this block column
* - values (long): Number of values in this block column
* - nulls (long): Number of null values in this block column
*
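* For example, inspecting per-column statistics (a sketch; `spark` is assumed to be an active SparkSession and the
* path a placeholder):
* {{{
* import uk.co.gresearch.spark.parquet._
* spark.read.parquetBlockColumns("/path/to/data.parquet").select("column", "minValue", "maxValue", "nulls").show()
* }}}
*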
* @param paths
* one or more paths to Parquet files or directories
* @return
* dataframe with Parquet block metadata
*/
@scala.annotation.varargs
def parquetBlockColumns(paths: String*): DataFrame = parquetBlockColumns(None, paths)
/**
* Read the metadata of Parquet block columns into a Dataframe.
*
* The returned DataFrame has as many partitions as specified via `parallelism`.
*
* This provides the following per-block-column information:
* - filename (string): The file name
* - block (int): Block / RowGroup number starting at 1 (block ordinal + 1)
* - column (string array): Block / RowGroup column path
* - codec (string): The codec used to compress the block column values
* - type (string): The data type of the block column
* - encodings (string array): Encodings of the block column
* - encrypted (boolean): True if the block column is encrypted
* - minValue (string): Minimum value of this column in this block
* - maxValue (string): Maximum value of this column in this block
* - columnStart (long): Start position of the block column in the Parquet file
* - compressedBytes (long): Number of compressed bytes of this block column
* - uncompressedBytes (long): Number of uncompressed bytes of this block column
* - values (long): Number of values in this block column
* - nulls (long): Number of null values in this block column
*
* @param parallelism
* number of partitions of returned DataFrame
* @param paths
* one or more paths to Parquet files or directories
* @return
* dataframe with Parquet block metadata
*/
@scala.annotation.varargs
def parquetBlockColumns(parallelism: Int, paths: String*): DataFrame = parquetBlockColumns(Some(parallelism), paths)
private def parquetBlockColumns(parallelism: Option[Int], paths: Seq[String]): DataFrame = {
val files = getFiles(parallelism, paths)
import files.sparkSession.implicits._
files
.flatMap { case (_, file) =>
readFooters(file).flatMap { footer =>
val guard = FooterGuard(footer)
footer.getParquetMetadata.getBlocks.asScala.zipWithIndex.flatMap { case (block, idx) =>
block.getColumns.asScala.map { column =>
(
footer.getFile.toString,
ParquetMetaDataUtil.getOrdinal(block).getOrElse(idx) + 1,
column.getPath.toSeq,
guard { column.getCodec.toString },
guard { column.getPrimitiveType.toString },
guard { column.getEncodings.asScala.toSeq.map(_.toString).sorted },
ParquetMetaDataUtil.isEncrypted(column),
guard { Option(column.getStatistics).map(_.minAsString) },
guard { Option(column.getStatistics).map(_.maxAsString) },
guard { column.getStartingPos },
guard { column.getTotalSize },
guard { column.getTotalUncompressedSize },
guard { column.getValueCount },
guard { Option(column.getStatistics).map(_.getNumNulls) },
)
}
}
}
}
.toDF(
"filename",
"block",
"column",
"codec",
"type",
"encodings",
"encrypted",
"minValue",
"maxValue",
"columnStart",
"compressedBytes",
"uncompressedBytes",
"values",
"nulls"
)
}
/**
* Read the metadata of how Spark partitions Parquet files into a Dataframe.
*
* The returned DataFrame has as many partitions as there are Parquet files, at most
* `spark.sparkContext.defaultParallelism` partitions.
*
* This provides the following per-partition information:
* - partition (int): The Spark partition id
* - start (long): The start position of the partition
* - end (long): The end position of the partition
* - length (long): The length of the partition
* - blocks (int): The number of Parquet blocks / RowGroups in this partition
* - compressedBytes (long): The number of compressed bytes in this partition
* - uncompressedBytes (long): The number of uncompressed bytes in this partition
* - rows (long): The number of rows in this partition
* - columns (int): Number of columns in the file
* - values (long): The number of values in this partition
* - filename (string): The Parquet file name
* - fileLength (long): The length of the Parquet file
*
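* For example, inspecting how rows spread over partitions (a sketch; `spark` is assumed to be an active
* SparkSession and the path a placeholder):
* {{{
* import uk.co.gresearch.spark.parquet._
* spark.read.parquetPartitions("/path/to/data.parquet").select("partition", "blocks", "rows", "filename").show()
* }}}
*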
* @param paths
* one or more paths to Parquet files or directories
* @return
* dataframe with Spark Parquet partition metadata
*/
@scala.annotation.varargs
def parquetPartitions(paths: String*): DataFrame = parquetPartitions(None, paths)
/**
* Read the metadata of how Spark partitions Parquet files into a Dataframe.
*
* The returned DataFrame has as many partitions as specified via `parallelism`.
*
* This provides the following per-partition information:
* - partition (int): The Spark partition id
* - start (long): The start position of the partition
* - end (long): The end position of the partition
* - length (long): The length of the partition
* - blocks (int): The number of Parquet blocks / RowGroups in this partition
* - compressedBytes (long): The number of compressed bytes in this partition
* - uncompressedBytes (long): The number of uncompressed bytes in this partition
* - rows (long): The number of rows in this partition
* - columns (int): Number of columns in the file
* - values (long): The number of values in this partition
* - filename (string): The Parquet file name
* - fileLength (long): The length of the Parquet file
*
* @param parallelism
* number of partitions of returned DataFrame
* @param paths
* one or more paths to Parquet files or directories
* @return
* dataframe with Spark Parquet partition metadata
*/
@scala.annotation.varargs
def parquetPartitions(parallelism: Int, paths: String*): DataFrame = parquetPartitions(Some(parallelism), paths)
private def parquetPartitions(parallelism: Option[Int], paths: Seq[String]): DataFrame = {
val files = getFiles(parallelism, paths)
import files.sparkSession.implicits._
files
.flatMap { case (part, file) =>
readFooters(file)
.map(footer => (footer, getBlocks(footer, file.start, file.length)))
.map { case (footer, blocks) =>
(
part,
file.start,
file.start + file.length,
file.length,
blocks.size,
blocks.map(_.getCompressedSize).sum,
blocks.map(_.getTotalByteSize).sum,
blocks.map(_.getRowCount).sum,
blocks
.map(_.getColumns.map(_.getPath.mkString(".")).toSet)
.foldLeft(Set.empty[String])((left, right) => left.union(right))
.size,
blocks.map(_.getColumns.asScala.map(_.getValueCount).sum).sum,
// when all columns have statistics, count the null values
Option(blocks.flatMap(_.getColumns.asScala.map(c => Option(c.getStatistics))))
.filter(_.forall(_.isDefined))
.map(_.map(_.get.getNumNulls).sum),
footer.getFile.toString,
file.fileSize,
)
}
}
.toDF(
"partition",
"start",
"end",
"length",
"blocks",
"compressedBytes",
"uncompressedBytes",
"rows",
"columns",
"values",
"nulls",
"filename",
"fileLength"
)
}
private def getFiles(parallelism: Option[Int], paths: Seq[String]): Dataset[(Int, SplitFile)] = {
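// enumerate the distinct (partition index, file split) pairs of the Parquet scan as a Dataset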
val df = reader.parquet(paths: _*)
val parts = df.rdd.partitions
.flatMap(part =>
part
.asInstanceOf[FilePartition]
.files
.map(file => (part.index, SplitFile(file)))
)
.toSeq
.distinct
import df.sparkSession.implicits._
parts
.toDS()
.when(parallelism.isDefined)
.call(_.repartition(parallelism.get))
}
}
private def readFooters(file: SplitFile): Iterable[Footer] = {
val path = new Path(file.filePath)
val status = path.getFileSystem(conf).getFileStatus(path)
ParquetFileReader.readFooters(conf, status, false).asScala
}
private def getBlocks(footer: Footer, start: Long, length: Long): Seq[BlockMetaData] = {
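// a block belongs to this partition if its mid-point (start position plus half the
// compressed size) lies within [start, start + length)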
footer.getParquetMetadata.getBlocks.asScala
.map(block => (block, block.getStartingPos + block.getCompressedSize / 2))
.filter { case (_, midBlock) => start <= midBlock && midBlock < start + length }
.map(_._1)
.toSeq
}
}
================================================
FILE: src/main/scala-spark-3.2/uk/co/gresearch/spark/parquet/SplitFile.scala
================================================
/*
* Copyright 2023 G-Research
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package uk.co.gresearch.spark.parquet
import org.apache.spark.sql.execution.datasources.PartitionedFile
private[spark] case class SplitFile(filePath: String, start: Long, length: Long, fileSize: Option[Long])
private[spark] object SplitFile {
def apply(file: PartitionedFile): SplitFile = SplitFile(file.filePath, file.start, file.length, None)
}
================================================
FILE: src/main/scala-spark-3.3/uk/co/gresearch/spark/parquet/SplitFile.scala
================================================
/*
* Copyright 2023 G-Research
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package uk.co.gresearch.spark.parquet
import org.apache.spark.sql.execution.datasources.PartitionedFile
private[spark] case class SplitFile(filePath: String, start: Long, length: Long, fileSize: Option[Long])
private[spark] object SplitFile {
def apply(file: PartitionedFile): SplitFile = SplitFile(file.filePath, file.start, file.length, Some(file.fileSize))
}
================================================
FILE: src/main/scala-spark-3.5/org/apache/spark/sql/extension/package.scala
================================================
/*
* Copyright 2024 G-Research
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spark.sql
import org.apache.spark.sql.catalyst.expressions.Expression
package object extension {
implicit class ColumnExtension(col: Column) {
// Column.expr exists in this Spark version and earlier
def sql: String = col.expr.sql
}
implicit class ExpressionExtension(expr: Expression) {
def column: Column = new Column(expr)
}
}
================================================
FILE: src/main/scala-spark-3.5/uk/co/gresearch/spark/Backticks.scala
================================================
/*
* Copyright 2021 G-Research
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package uk.co.gresearch.spark
import java.util.regex.Pattern
object Backticks {
// https://github.com/apache/spark/blob/523ff15/sql/api/src/main/scala/org/apache/spark/sql/catalyst/util/QuotingUtils.scala#L46
private val validIdentPattern = Pattern.compile("^[a-zA-Z_][a-zA-Z0-9_]*")
/**
* Detects if column name part requires quoting.
* https://github.com/apache/spark/blob/523ff15/sql/api/src/main/scala/org/apache/spark/sql/catalyst/util/QuotingUtils.scala#L48
*/
private def needQuote(part: String): Boolean = {
!validIdentPattern.matcher(part).matches()
}
/**
* Encloses the given strings with backticks (backquotes) if needed.
*
* Backticks are not needed for strings that start with a letter (`a`-`z` and `A`-`Z`) or an underscore,
* and contain only letters, numbers and underscores.
*
* Multiple strings will be enclosed individually and concatenated with dots (`.`).
*
* This is useful when referencing column names that contain special characters like dots (`.`) or backquotes.
*
* Examples:
* {{{
* col("a.column") // this references the field "column" of column "a"
* col("`a.column`") // this reference the column with the name "a.column"
* col(Backticks.column_name("column")) // produces "column"
* col(Backticks.column_name("a.column")) // produces "`a.column`"
* col(Backticks.column_name("a column")) // produces "`a column`"
* col(Backticks.column_name("`a.column`")) // produces "`a.column`"
* col(Backticks.column_name("a.column", "a.field")) // produces "`a.column`.`a.field`"
* }}}
*
* @param string
* a string
* @param strings
* more strings
* @return
* the given strings, backtick-quoted where needed and concatenated with dots
*/
@scala.annotation.varargs
def column_name(string: String, strings: String*): String =
(string +: strings)
.map(s => if (needQuote(s)) s"`${s.replace("`", "``")}`" else s)
.mkString(".")
}
================================================
FILE: src/main/scala-spark-4.0/org/apache/spark/sql/extension/package.scala
================================================
/*
* Copyright 2024 G-Research
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spark.sql
import org.apache.spark.sql.catalyst.expressions.Expression
import org.apache.spark.sql.classic.ExpressionUtils.{column => toColumn, expression}
package object extension {
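// In Spark 4.0, Column no longer exposes its Expression; these extensions recover
// expr, sql, and Expression-to-Column conversion via the classic ExpressionUtils bridge.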
implicit class ColumnExtension(col: Column) {
def expr: Expression = expression(col)
def sql: String = col.node.sql
}
implicit class ExpressionExtension(expr: Expression) {
def column: Column = toColumn(expr)
}
}
================================================
FILE: src/main/scala-spark-4.0/uk/co/gresearch/spark/Backticks.scala
================================================
/*
* Copyright 2021 G-Research
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package uk.co.gresearch.spark
import org.apache.spark.sql.catalyst.util.QuotingUtils
object Backticks {
/**
* Encloses the given strings with backticks (backquotes) if needed.
*
* Backticks are not needed for strings that start with a letter (`a`-`z` and `A`-`Z`) or an underscore,
* and contain only letters, numbers and underscores.
*
* Multiple strings will be enclosed individually and concatenated with dots (`.`).
*
* This is useful when referencing column names that contain special characters like dots (`.`) or backquotes.
*
* Examples:
* {{{
* col("a.column") // this references the field "column" of column "a"
* col("`a.column`") // this reference the column with the name "a.column"
* col(Backticks.column_name("column")) // produces "column"
* col(Backticks.column_name("a.column")) // produces "`a.column`"
* col(Backticks.column_name("a column")) // produces "`a column`"
* col(Backticks.column_name("`a.column`")) // produces "`a.column`"
* col(Backticks.column_name("a.column", "a.field")) // produces "`a.column`.`a.field`"
* }}}
*
* @param string
* a string
* @param strings
* more strings
* @return
* the given strings, backtick-quoted where needed and concatenated with dots
*/
@scala.annotation.varargs
def column_name(string: String, strings: String*): String =
QuotingUtils.quoted(Array.from(string +: strings))
}
================================================
FILE: src/main/scala-spark-4.0/uk/co/gresearch/spark/parquet/SplitFile.scala
================================================
/*
* Copyright 2023 G-Research
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package uk.co.gresearch.spark.parquet
import org.apache.spark.sql.execution.datasources.PartitionedFile
private[spark] case class SplitFile(filePath: String, start: Long, length: Long, fileSize: Option[Long])
private[spark] object SplitFile {
def apply(file: PartitionedFile): SplitFile = SplitFile(file.filePath.toString, file.start, file.length, Some(file.fileSize))
}
================================================
FILE: src/test/java/uk/co/gresearch/test/SparkJavaTests.java
================================================
/*
* Copyright 2021 G-Research
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
// these tests are deliberately located outside uk.co.gresearch.spark to show how imports look for Java
package uk.co.gresearch.test;
import org.apache.spark.SparkConf;
import org.apache.spark.sql.*;
import org.apache.spark.sql.execution.CacheManager;
import org.apache.spark.storage.StorageLevel;
import org.junit.AfterClass;
import org.junit.Assert;
import org.junit.BeforeClass;
import org.junit.Test;
import uk.co.gresearch.spark.Backticks;
import uk.co.gresearch.spark.Histogram;
import uk.co.gresearch.spark.RowNumbers;
import uk.co.gresearch.spark.UnpersistHandle;
import uk.co.gresearch.spark.diff.JavaValue;
import java.util.Arrays;
import java.util.List;
public class SparkJavaTests {
private static SparkSession spark;
private static Dataset<JavaValue> dataset;
@BeforeClass
public static void beforeClass() {
spark = SparkSession
.builder()
.master("local[*]")
.config(new SparkConf().set("fs.defaultFS", "file:///"))
.appName("Diff Java Suite")
.getOrCreate();
JavaValue valueOne = new JavaValue(1, "one", 1.0);
JavaValue valueTwo = new JavaValue(2, "two", 2.0);
JavaValue valueThree = new JavaValue(3, "three", 3.0);
Encoder<JavaValue> encoder = Encoders.bean(JavaValue.class);
dataset = spark.createDataset(Arrays.asList(valueOne, valueTwo, valueThree), encoder);
}
@Test
public void testBackticks() {
Assert.assertEquals("col", Backticks.column_name("col"));
Assert.assertEquals("`a.col`", Backticks.column_name("a.col"));
Assert.assertEquals("a.col", Backticks.column_name("a", "col"));
Assert.assertEquals("some.more.columns", Backticks.column_name("some", "more", "columns"));
Assert.assertEquals("some.`more.columns`", Backticks.column_name("some", "more.columns"));
Assert.assertEquals("some.more.dotted.columns", Backticks.column_name("some", "more", "dotted", "columns"));
Assert.assertEquals("some.more.`dotted.columns`", Backticks.column_name("some", "more", "dotted.columns"));
}
@Test
public void testHistogram() {
Dataset histogram = Histogram.of(dataset, Arrays.asList(0, 1, 2), new Column("id"));
List expected = Arrays.asList(RowFactory.create(0, 1, 1, 1));
Assert.assertEquals(expected, histogram.collectAsList());
}
@Test
public void testHistogramWithAggColumn() {
Dataset<Row> histogram = Histogram.of(dataset, Arrays.asList(0, 1, 2), new Column("id"), new Column("label"));
List<Row> expected = Arrays.asList(
RowFactory.create("one", 0, 1, 0, 0),
RowFactory.create("three", 0, 0, 0, 1),
RowFactory.create("two", 0, 0, 1, 0)
);
Assert.assertEquals(expected, histogram.sort("label").collectAsList());
}
@Test
public void testRowNumbers() {
Dataset<Row> withRowNumbers = RowNumbers.of(dataset);
List<Row> expected = Arrays.asList(
RowFactory.create(1, "one", 1.0, 1),
RowFactory.create(2, "two", 2.0, 2),
RowFactory.create(3, "three", 3.0, 3)
);
Assert.assertEquals(expected, withRowNumbers.orderBy("id").collectAsList());
}
@Test
public void testRowNumbersOrderOneColumn() {
Dataset withRowNumbers = RowNumbers.withOrderColumns(dataset.col("id").desc()).of(dataset);
List expected = Arrays.asList(
RowFactory.create(1, "one", 1.0, 3),
RowFactory.create(2, "two", 2.0, 2),
RowFactory.create(3, "three", 3.0, 1)
);
Assert.assertEquals(expected, withRowNumbers.orderBy("id").collectAsList());
}
@Test
public void testRowNumbersOrderTwoColumns() {
Dataset withRowNumbers = RowNumbers.withOrderColumns(dataset.col("id"), dataset.col("label")).of(dataset);
List expected = Arrays.asList(
RowFactory.create(1, "one", 1.0, 1),
RowFactory.create(2, "two", 2.0, 2),
RowFactory.create(3, "three", 3.0, 3)
);
Assert.assertEquals(expected, withRowNumbers.orderBy("id").collectAsList());
}
@Test
public void testRowNumbersOrderDesc() {
Dataset withRowNumbers = RowNumbers.withOrderColumns(dataset.col("id").desc()).of(dataset);
List expected = Arrays.asList(
RowFactory.create(1, "one", 1.0, 3),
RowFactory.create(2, "two", 2.0, 2),
RowFactory.create(3, "three", 3.0, 1)
);
Assert.assertEquals(expected, withRowNumbers.orderBy("id").collectAsList());
}
@Test
public void testRowNumbersUnpersist() {
CacheManager cacheManager = SparkJavaTests.spark.sharedState().cacheManager();
cacheManager.clearCache();
Assert.assertTrue(cacheManager.isEmpty());
UnpersistHandle unpersist = new UnpersistHandle();
Dataset<Row> withRowNumbers = RowNumbers.withUnpersistHandle(unpersist).of(dataset);
List<Row> expected = Arrays.asList(
RowFactory.create(1, "one", 1.0, 1),
RowFactory.create(2, "two", 2.0, 2),
RowFactory.create(3, "three", 3.0, 3)
);
Assert.assertEquals(expected, withRowNumbers.orderBy("id").collectAsList());
Assert.assertFalse(cacheManager.isEmpty());
unpersist.apply(true);
Assert.assertTrue(cacheManager.isEmpty());
}
@Test
public void testRowNumbersStorageLevelAndUnpersist() {
CacheManager cacheManager = SparkJavaTests.spark.sharedState().cacheManager();
cacheManager.clearCache();
Assert.assertTrue(cacheManager.isEmpty());
UnpersistHandle unpersist = new UnpersistHandle();
RowNumbers.withStorageLevel(StorageLevel.MEMORY_ONLY()).withUnpersistHandle(unpersist).of(dataset);
Assert.assertFalse(cacheManager.isEmpty());
unpersist.apply(true);
Assert.assertTrue(cacheManager.isEmpty());
}
@Test
public void testRowNumbersColumnName() {
Dataset withRowNumbers = RowNumbers.withRowNumberColumnName("row").of(dataset);
Assert.assertEquals(Arrays.asList("id", "label", "score", "row"), Arrays.asList(withRowNumbers.columns()));
List<Row> expected = Arrays.asList(
RowFactory.create(1, "one", 1.0, 1),
RowFactory.create(2, "two", 2.0, 2),
RowFactory.create(3, "three", 3.0, 3)
);
Assert.assertEquals(expected, withRowNumbers.orderBy("id").collectAsList());
}
@AfterClass
public static void afterClass() {
if (spark != null) {
spark.stop();
}
}
}
================================================
FILE: src/test/java/uk/co/gresearch/test/diff/DiffJavaTests.java
================================================
/*
* Copyright 2021 G-Research
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package uk.co.gresearch.spark.diff;
import org.apache.spark.SparkConf;
import org.apache.spark.sql.*;
import org.apache.spark.sql.types.DataTypes;
import org.junit.AfterClass;
import org.junit.Assert;
import org.junit.BeforeClass;
import org.junit.Test;
import scala.Tuple3;
import scala.math.Equiv;
import uk.co.gresearch.spark.diff.comparator.DiffComparator;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;
import static java.lang.Math.abs;
public class DiffJavaTests {
private static SparkSession spark;
private static Dataset<JavaValue> left;
private static Dataset<JavaValue> right;
@BeforeClass
public static void beforeClass() {
spark = SparkSession
.builder()
.master("local[*]")
.config(new SparkConf().set("fs.defaultFS", "file:///"))
.appName("Diff Java Suite")
.getOrCreate();
JavaValue valueOne = new JavaValue(1, "one", 1.0);
JavaValue valueTwo = new JavaValue(2, "two", 2.0);
JavaValue valueThree = new JavaValue(3, "three", 3.0);
JavaValue valueThreeScored = new JavaValue(3, "three", 3.1);
JavaValue valueFour = new JavaValue(4, "four", 4.0);
Encoder<JavaValue> encoder = Encoders.bean(JavaValue.class);
left = spark.createDataset(Arrays.asList(valueOne, valueTwo, valueThree), encoder);
right = spark.createDataset(Arrays.asList(valueTwo, valueThreeScored, valueFour), encoder);
}
@Test
public void testDiff() {
Dataset diff = Diff.of(left.toDF(), right.toDF(), "id");
List expected = Arrays.asList(
RowFactory.create("D", 1, "one", null, 1.0, null),
RowFactory.create("N", 2, "two", "two", 2.0, 2.0),
RowFactory.create("C", 3, "three", "three", 3.0, 3.1),
RowFactory.create("I", 4, null, "four", null, 4.0)
);
Assert.assertEquals(expected, diff.sort("id").collectAsList());
}
@Test
public void testDiffNoKey() {
Dataset<Row> diff = Diff.of(left, right);
List<Row> expected = Arrays.asList(
RowFactory.create("D", 1, "one", 1.0),
RowFactory.create("N", 2, "two", 2.0),
RowFactory.create("D", 3, "three", 3.0),
RowFactory.create("I", 3, "three", 3.1),
RowFactory.create("I", 4, "four", 4.0)
);
Assert.assertEquals(expected, diff.sort("id", "diff").collectAsList());
}
@Test
public void testDiffSingleKey() {
Dataset diff = Diff.of(left, right, "id");
List expected = Arrays.asList(
RowFactory.create("D", 1, "one", null, 1.0, null),
RowFactory.create("N", 2, "two", "two", 2.0, 2.0),
RowFactory.create("C", 3, "three", "three", 3.0, 3.1),
RowFactory.create("I", 4, null, "four", null, 4.0)
);
Assert.assertEquals(expected, diff.sort("id").collectAsList());
}
@Test
public void testDiffMultipleKeys() {
Dataset diff = Diff.of(left, right, "id", "label");
List expected = Arrays.asList(
RowFactory.create("D", 1, "one", 1.0, null),
RowFactory.create("N", 2, "two", 2.0, 2.0),
RowFactory.create("C", 3, "three", 3.0, 3.1),
RowFactory.create("I", 4, "four", null, 4.0)
);
Assert.assertEquals(expected, diff.sort("id").collectAsList());
}
@Test
public void testDiffIgnoredColumn() {
Dataset diff = Diff.of(left, right, Collections.singletonList("id"), Collections.singletonList("score"));
List expected = Arrays.asList(
RowFactory.create("D", 1, "one", null, 1.0, null),
RowFactory.create("N", 2, "two", "two", 2.0, 2.0),
RowFactory.create("N", 3, "three", "three", 3.0, 3.1),
RowFactory.create("I", 4, null, "four", null, 4.0)
);
Assert.assertEquals(expected, diff.sort("id").collectAsList());
}
@Test
public void testDiffAs() {
Encoder<JavaValueAs> encoder = Encoders.bean(JavaValueAs.class);
Dataset<JavaValueAs> diff = Diff.ofAs(left.toDF(), right.toDF(), encoder, "id");
List<JavaValueAs> expected = Arrays.asList(
new JavaValueAs("D", 1, "one", null, 1.0, null),
new JavaValueAs("N", 2, "two", "two", 2.0, 2.0),
new JavaValueAs("C", 3, "three", "three", 3.0, 3.1),
new JavaValueAs("I", 4, null, "four", null, 4.0)
);
Assert.assertEquals(expected, diff.sort("id").collectAsList());
}
@Test
public void testDiffOfWith() {
Dataset> diff = Diff.ofWith(left, right, "id");
List> expected = Arrays.asList(
new Tuple3<>("D", new JavaValue(1, "one", 1.0), null),
new Tuple3<>("N", new JavaValue(2, "two", 2.0), new JavaValue(2, "two", 2.0)),
new Tuple3<>("C", new JavaValue(3, "three", 3.0), new JavaValue(3, "three", 3.1)),
new Tuple3<>("I", null, new JavaValue(4, "four", 4.0))
);
Assert.assertEquals(expected, diff.sort("id").collectAsList());
}
@Test
public void testDiffer() {
DiffOptions options = new DiffOptions();
Differ differ = new Differ(options);
Dataset diff = differ.diff(left, right, "id");
List expected = Arrays.asList(
RowFactory.create("D", 1, "one", null, 1.0, null),
RowFactory.create("N", 2, "two", "two", 2.0, 2.0),
RowFactory.create("C", 3, "three", "three", 3.0, 3.1),
RowFactory.create("I", 4, null, "four", null, 4.0)
);
Assert.assertEquals(expected, diff.sort("id").collectAsList());
}
@Test
public void testDifferWithIgnored() {
DiffOptions options = new DiffOptions();
Differ differ = new Differ(options);
Dataset diff = differ.diff(left, right, Collections.singletonList("id"), Collections.singletonList("score"));
List expected = Arrays.asList(
RowFactory.create("D", 1, "one", null, 1.0, null),
RowFactory.create("N", 2, "two", "two", 2.0, 2.0),
RowFactory.create("N", 3, "three", "three", 3.0, 3.1),
RowFactory.create("I", 4, null, "four", null, 4.0)
);
Assert.assertEquals(expected, diff.sort("id").collectAsList());
List<String> columns = Arrays.asList(diff.schema().fieldNames());
Assert.assertEquals(Arrays.asList("diff", "id", "left_label", "right_label", "left_score", "right_score"), columns);
}
@Test
public void testDiffWithOptions() {
DiffOptions options = new DiffOptions(
"action",
"before", "after",
"+", "~", "-", "=",
scala.Option.apply(null),
DiffMode.ColumnByColumn(),
false
);
Differ differ = new Differ(options);
Dataset diff = differ.diff(left, right, "id");
List expected = Arrays.asList(
RowFactory.create("-", 1, "one", null, 1.0, null),
RowFactory.create("=", 2, "two", "two", 2.0, 2.0),
RowFactory.create("~", 3, "three", "three", 3.0, 3.1),
RowFactory.create("+", 4, null, "four", null, 4.0)
);
Assert.assertEquals(expected, diff.sort("id").collectAsList());
List<String> names = Arrays.asList(diff.schema().fieldNames());
Assert.assertEquals(Arrays.asList("action", "id", "before_label", "after_label", "before_score", "after_score"), names);
}
@Test
public void testDiffWithComparators() {
DiffComparator comparator = DiffComparators.epsilon(0.100000001).asInclusive().asAbsolute();
testDiffWithComparator(new DiffOptions().withComparator(comparator, DataTypes.DoubleType));
testDiffWithComparator(new DiffOptions().withComparator(comparator, "score"));
Equiv<Double> equivDouble = (Double x, Double y) -> x == null && y == null || x != null && y != null &&
abs(x - y) <= 0.1000000001;
testDiffWithComparator(new DiffOptions().withComparator(equivDouble, Encoders.DOUBLE()));
testDiffWithComparator(new DiffOptions().withComparator(equivDouble, Encoders.DOUBLE(), "score"));
testDiffWithComparator(new DiffOptions().withComparator(equivDouble, DataTypes.DoubleType));
Equiv