Repository: dfdx/Spark.jl Branch: main Commit: 62e68c7dea7d Files: 50 Total size: 108.9 KB Directory structure: gitextract_jhdue1jq/ ├── .editorconfig ├── .github/ │ └── workflows/ │ ├── TagBot.yml │ ├── docs.yml │ └── test.yml ├── .gitignore ├── LICENSE.md ├── Project.toml ├── README.md ├── deps/ │ └── build.jl ├── docs/ │ ├── .gitignore │ ├── Project.toml │ ├── localdocs.sh │ ├── make.jl │ └── src/ │ ├── api.md │ ├── index.md │ ├── sql.md │ └── streaming.md ├── examples/ │ ├── InstallJuliaEMR.sh │ ├── InstallJuliaHDI.sh │ └── SparkSubmitJulia.scala ├── jvm/ │ └── sparkjl/ │ ├── dependency-reduced-pom.xml │ ├── old_src/ │ │ ├── InputIterator.scala │ │ ├── JuliaRDD.scala │ │ ├── JuliaRunner.scala │ │ ├── OutputThread.scala │ │ ├── RDDUtils.scala │ │ └── StreamUtils.scala │ └── pom.xml ├── src/ │ ├── Spark.jl │ ├── chainable.jl │ ├── column.jl │ ├── compiler.jl │ ├── convert.jl │ ├── core.jl │ ├── dataframe.jl │ ├── defs.jl │ ├── init.jl │ ├── io.jl │ ├── row.jl │ ├── session.jl │ ├── streaming.jl │ ├── struct.jl │ └── window.jl └── test/ ├── data/ │ ├── people.json │ └── people2.json ├── runtests.jl ├── test_chainable.jl ├── test_compiler.jl ├── test_convert.jl └── test_sql.jl ================================================ FILE CONTENTS ================================================ ================================================ FILE: .editorconfig ================================================ ================================================ FILE: .github/workflows/TagBot.yml ================================================ name: TagBot on: issue_comment: types: - created workflow_dispatch: inputs: lookback: default: 3 permissions: contents: write jobs: TagBot: if: github.event_name == 'workflow_dispatch' || github.actor == 'JuliaTagBot' runs-on: ubuntu-latest steps: - uses: JuliaRegistries/TagBot@v1 with: token: ${{ secrets.GITHUB_TOKEN }} ssh: ${{ secrets.DOCUMENTER_KEY }} ================================================ FILE: .github/workflows/docs.yml ================================================ name: Documentation on: push: branches: - main tags: '*' pull_request: jobs: build: runs-on: ubuntu-latest steps: - uses: actions/checkout@v2 - uses: julia-actions/setup-julia@latest with: version: '1.7' - name: Install dependencies run: julia --project=docs/ -e 'using Pkg; Pkg.develop(PackageSpec(path=pwd())); Pkg.instantiate(); Pkg.build()' - name: Build and deploy env: GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} # For authentication with GitHub Actions token DOCUMENTER_KEY: ${{ secrets.DOCUMENTER_KEY }} # For authentication with SSH deploy key run: julia --project=docs/ docs/make.jl ================================================ FILE: .github/workflows/test.yml ================================================ name: Test on: push: branches: - main pull_request: branches: - main jobs: build: name: Julia ${{ matrix.version }} - ${{ matrix.os }} - ${{ matrix.arch }} runs-on: ${{ matrix.os }} strategy: fail-fast: false matrix: version: - '1.6' - '1.7' os: - ubuntu-latest arch: - x64 steps: - uses: actions/checkout@v2 - name: Set up JDK uses: actions/setup-java@v2 with: java-version: '8' distribution: 'adopt' - uses: julia-actions/setup-julia@latest with: version: ${{ matrix.version }} arch: ${{ matrix.arch }} - uses: julia-actions/julia-buildpkg@latest - uses: julia-actions/julia-runtest@latest - uses: julia-actions/julia-uploadcodecov@latest env: CODECOV_TOKEN: ${{ secrets.CODECOV_TOKEN }} ================================================ FILE: .gitignore 
================================================ *.jl.cov *.jl.mem *~ .idea/ .vscode/ target/ project/ *.class *.jar .juliahistory *.iml *.log nohup.out docs/build docs/site .DS_Store deps/hadoop Manifest.toml # hidden files _* ================================================ FILE: LICENSE.md ================================================ The Spark.jl package is licensed under the MIT "Expat" License: > Copyright (c) 2015: dfdx. > > Permission is hereby granted, free of charge, to any person obtaining > a copy of this software and associated documentation files (the > "Software"), to deal in the Software without restriction, including > without limitation the rights to use, copy, modify, merge, publish, > distribute, sublicense, and/or sell copies of the Software, and to > permit persons to whom the Software is furnished to do so, subject to > the following conditions: > > The above copyright notice and this permission notice shall be > included in all copies or substantial portions of the Software. > > THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, > EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF > MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. > IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY > CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, > TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE > SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. ================================================ FILE: Project.toml ================================================ name = "Spark" uuid = "e3819d11-95af-5eea-9727-70c091663a01" version = "0.6.1" [deps] Dates = "ade2ca70-3891-5945-98fb-dc099432e06a" IteratorInterfaceExtensions = "82899510-4779-5014-852e-03e436cf321d" JavaCall = "494afd89-becb-516b-aafa-70d2670c0337" Reexport = "189a3867-3050-52da-a836-e630ba90ab69" Serialization = "9e88b42a-f829-5b0c-bbe9-9e923198166b" Sockets = "6462fe0b-24de-5631-8697-dd941f90decc" Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" TableTraits = "3783bdb8-4a98-5b6b-af9a-565f29a5fe9c" Umlaut = "92992a2b-8ce5-4a9c-bb9d-58be9a7dc841" [compat] IteratorInterfaceExtensions = "1" JavaCall = "0.7, 0.8" Reexport = "1.2" TableTraits = "1" Umlaut = "0.2" julia = "1.6" [extras] DataFrames = "a93c6f00-e57d-5684-b7b6-d8193f3e46c0" Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" [targets] test = ["Test", "DataFrames"] ================================================ FILE: README.md ================================================ # Spark.jl A Julia interface to Apache Spark™ | **Latest Version** | **Documentation** | **PackageEvaluator** | **Build Status** | |:------------------:|:-----------------:|:--------------------:|:----------------:| | [![][version-img]][version-url] | [![][docs-latest-img]][docs-latest-url] | [![PkgEval][pkgeval-img]][pkgeval-url] | [![][gh-test-img]][gh-test-url] | Spark.jl provides an interface to Apache Spark™ platform, including SQL / DataFrame and Structured Streaming. It closely follows the PySpark API, making it easy to translate existing Python code to Julia. Spark.jl supports multiple cluster types (in client mode), and can be considered as an analogue to PySpark or RSpark within the Julia ecosystem. It supports running within on-premise installations, as well as hosted instance such as Amazon EMR and Azure HDInsight. 
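A quick taste of the API (this is the Quick Example from the documentation; it assumes a local Spark master and a working JDK + Maven installation):

```julia
using Spark

# Dot-chaining follows the PySpark style: spark.createDataFrame(...)
# is translated to createDataFrame(spark, ...)
spark = SparkSession.builder.appName("Main").master("local").getOrCreate()

df = spark.createDataFrame([["Alice", 19], ["Bob", 23]], "name string, age long")
rows = df.select(Column("age") + 1).collect()

for row in rows
    println(row[1])
end
```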
**[Documentation][docs-latest-url]** ## Trademarks Apache®, [Apache Spark and Spark](http://spark.apache.org) are registered trademarks, or trademarks of the [Apache Software Foundation](http://www.apache.org/) in the United States and/or other countries. [docs-latest-img]: https://img.shields.io/badge/docs-latest-blue.svg [docs-latest-url]: http://dfdx.github.io/Spark.jl/dev [gh-test-img]: https://github.com/dfdx/Spark.jl/actions/workflows/test.yml/badge.svg [gh-test-url]: https://github.com/dfdx/Spark.jl/actions/workflows/test.yml [codecov-img]: https://codecov.io/gh/dfdx/Spark.jl/branch/master/graph/badge.svg [codecov-url]: https://codecov.io/gh/dfdx/Spark.jl [issues-url]: https://github.com/dfdx/Spark.jl/issues [pkgeval-img]: https://juliahub.com/docs/Spark/pkgeval.svg [pkgeval-url]: https://juliahub.com/ui/Packages/Spark/zpJEw [version-img]: https://juliahub.com/docs/Spark/version.svg [version-url]: https://juliahub.com/ui/Packages/Spark/zpJEw ================================================ FILE: deps/build.jl ================================================ mvn = Sys.iswindows() ? "mvn.cmd" : "mvn" which = Sys.iswindows() ? "where" : "which" try run(`$which $mvn`) catch error("Cannot find maven. Is it installed?") end SPARK_VERSION = get(ENV, "BUILD_SPARK_VERSION", "3.2.1") SCALA_VERSION = get(ENV, "BUILD_SCALA_VERSION", "2.13") SCALA_BINARY_VERSION = get(ENV, "BUILD_SCALA_VERSION", "2.13.6") cd(joinpath(dirname(@__DIR__), "jvm/sparkjl")) do run(`$mvn clean package -Dspark.version=$SPARK_VERSION -Dscala.version=$SCALA_VERSION -Dscala.binary.version=$SCALA_BINARY_VERSION`) end ================================================ FILE: docs/.gitignore ================================================ data/ ================================================ FILE: docs/Project.toml ================================================ [deps] Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4" Spark = "e3819d11-95af-5eea-9727-70c091663a01" ================================================ FILE: docs/localdocs.sh ================================================ #!/bin/bash julia -e 'using LiveServer; serve(dir="build")' ================================================ FILE: docs/make.jl ================================================ using Documenter using Spark makedocs( sitename = "Spark", format = Documenter.HTML(), modules = [Spark], pages = Any[ "Introduction" => "index.md", "SQL / DataFrames" => "sql.md", "Structured Streaming" => "streaming.md", "API Reference" => "api.md" ], ) deploydocs( repo = "github.com/dfdx/Spark.jl.git", devbranch = "main", ) ================================================ FILE: docs/src/api.md ================================================ ```@meta CurrentModule = Spark ``` ```@docs SparkSessionBuilder SparkSession RuntimeConfig DataFrame GroupedData Column Row StructType StructField Window WindowSpec DataFrameReader DataFrameWriter DataStreamReader DataStreamWriter StreamingQuery @chainable DotChainer ``` ```@index ``` ================================================ FILE: docs/src/index.md ================================================ # Introduction ## Overview Spark.jl provides an interface to Apache Spark™ platform, including SQL / DataFrame and Structured Streaming. It closely follows the PySpark API, making it easy to translate existing Python code to Julia. Spark.jl supports multiple cluster types (in client mode), and can be considered as an analogue to PySpark or RSpark within the Julia ecosystem. 
It supports running within on-premise installations, as well as hosted instance such as Amazon EMR and Azure HDInsight. ### Installation Spark.jl requires at least JDK 8/11 and Maven to be installed and available in PATH. ```julia ] add Spark ``` To link against a specific version of Spark, also run: ```julia ENV["BUILD_SPARK_VERSION"] = "3.2.1" # version you need ] build Spark ``` ### Quick Example Note that most types in Spark.jl support dot notation for calling functions, e.g. `x.foo(y)` is expanded into `foo(x, y)`. ```@example using Spark spark = SparkSession.builder.appName("Main").master("local").getOrCreate() df = spark.createDataFrame([["Alice", 19], ["Bob", 23]], "name string, age long") rows = df.select(Column("age") + 1).collect() for row in rows println(row[1]) end ``` ### Cluster Types This package supports multiple cluster types (in client mode): `local`, `standalone`, `mesos` and `yarn`. The location of the cluster (in case of mesos or standalone) or the cluster type (in case of local or yarn) must be passed as a parameter `master` when creating a Spark context. For YARN based clusters, the cluster parameters are picked up from `spark-defaults.conf`, which must be accessible via a `SPARK_HOME` environment variable. ## Current Limitations * Jobs can be submitted from Julia process attached to the cluster in `client` deploy mode. `Cluster` mode is not fully supported, and it is uncertain if it is useful in the Julia context. * Since records are serialised between Java and Julia at the edges, the maximum size of a single row in an RDD is 2GB, due to Java array indices being limited to 32 bits. ## Trademarks Apache®, [Apache Spark and Spark](http://spark.apache.org) are registered trademarks, or trademarks of the [Apache Software Foundation](http://www.apache.org/) in the United States and/or other countries. ================================================ FILE: docs/src/sql.md ================================================ ```@meta CurrentModule = Spark ``` # SQL / DataFrames This is a quick introduction into the Spark.jl core functions. It closely follows the official [PySpark tutorial](https://spark.apache.org/docs/latest/api/python/getting_started/quickstart_df.html) and copies many examples verbatim. In most cases, PySpark docs should work for Spark.jl as is or with little adaptation. Spark.jl applications usually start by creating a `SparkSession`: ```@example using Spark spark = SparkSession.builder.appName("Main").master("local").getOrCreate() ``` Note that here we use dot notation to chain function invocations. This makes the code more concise and also mimics Python API, making translation of examples easier. The same example could also be written as: ```julia using Spark import Spark: appName, master, getOrCreate builder = SparkSession.builder builder = appName(builder, "Main") builder = master(builder, "local") spark = getOrCreate(builder) ``` See [`@chainable`](@ref) for the details of the dot notation. ## DataFrame Creation In simple cases, a Spark DataFrame can be created via `SparkSession.createDataFrame`. E.g. 
from a list of rows: ```@example df using Spark # hide spark = SparkSession.builder.getOrCreate() # hide using Dates df = spark.createDataFrame([ Row(a=1, b=2.0, c="string1", d=Date(2000, 1, 1), e=DateTime(2000, 1, 1, 12, 0)), Row(a=2, b=3.0, c="string2", d=Date(2000, 2, 1), e=DateTime(2000, 1, 2, 12, 0)), Row(a=4, b=5.0, c="string3", d=Date(2000, 3, 1), e=DateTime(2000, 1, 3, 12, 0)) ]) println(df) ``` Or using an explicit schema: ```@example df df = spark.createDataFrame([ [1, 2.0, "string1", Date(2000, 1, 1), DateTime(2000, 1, 1, 12, 0)], [2, 3.0, "string2", Date(2000, 2, 1), DateTime(2000, 1, 2, 12, 0)], [3, 4.0, "string3", Date(2000, 3, 1), DateTime(2000, 1, 3, 12, 0)] ], "a long, b double, c string, d date, e timestamp") println(df) ``` ## Viewing Data The top rows of a DataFrame can be displayed using `DataFrame.show()`: ```@example df df.show(1) ``` You can see the DataFrame’s schema and column names as follows: ```@example df df.columns() ``` ```@example df df.printSchema() ``` Show the summary of the DataFrame ```@example df df.select("a", "b", "c").describe().show() ``` `DataFrame.collect()` collects the distributed data to the driver side as the local data in Julia. Note that this can throw an out-of-memory error when the dataset is too large to fit in the driver side because it collects all the data from executors to the driver side. ```@example df df.collect() ``` In order to avoid throwing an out-of-memory exception, use `take()` or `tail()`. ```@example df df.take(1) ``` ## Selecting and Accessing Data Spark.jl `DataFrame` is lazily evaluated and simply selecting a column does not trigger the computation but it returns a `Column` instance. ```@example df df.a ``` In fact, most of column-wise operations return `Column`s. ```@example df typeof(df.c) == typeof(df.c.upper()) == typeof(df.c.isNull()) ``` These `Column`s can be used to select the columns from a `DataFrame`. For example, `select()` takes the `Column` instances that returns another `DataFrame`. ```@example df df.select(df.c).show() ``` Assign new Column instance. ```@example df df.withColumn("upper_c", df.c.upper()).show() ``` To select a subset of rows, use `filter()` (a.k.a. `where()`). ```@example df df.filter(df.a == 1).show() ``` ## Grouping Data Spark.jl `DataFrame` also provides a way of handling grouped data by using the common approach, split-apply-combine strategy. It groups the data by a certain condition applies a function to each group and then combines them back to the `DataFrame`. ```@example gdf using Spark # hide spark = SparkSession.builder.appName("Main").master("local").getOrCreate() # hide df = spark.createDataFrame([ ["red", "banana", 1, 10], ["blue", "banana", 2, 20], ["red", "carrot", 3, 30], ["blue", "grape", 4, 40], ["red", "carrot", 5, 50], ["black", "carrot", 6, 60], ["red", "banana", 7, 70], ["red", "grape", 8, 80]], ["color string", "fruit string", "v1 long", "v2 long"]) df.show() ``` Grouping and then applying the `avg()` function to the resulting groups. ```@example gdf df.groupby("color").avg().show() ``` ## Getting Data in/out Spark.jl can read and write a variety of data formats. Here's a few examples. 
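### JSON

JSON works in the same way; the sketch below assumes the reader and writer expose a `json` method mirroring the PySpark API (the repository ships small sample files such as `test/data/people.json` that can be used to try it out):

```julia
df.write.json("data/fruits.json")
spark.read.json("data/fruits.json")
```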
### CSV ```@example gdf df.write.option("header", true).csv("data/fruits.csv") spark.read.option("header", true).csv("data/fruits.csv") ``` ### Parquet ```@example gdf df.write.parquet("data/fruits.parquet") spark.read.parquet("data/fruits.parquet") ``` ### ORC ```@example gdf df.write.orc("data/fruits.orc") spark.read.orc("data/fruits.orc") ``` ## Working with SQL `DataFrame` and Spark SQL share the same execution engine so they can be interchangeably used seamlessly. For example, you can register the `DataFrame` as a table and run a SQL easily as below: ```@example gdf df.createOrReplaceTempView("tableA") spark.sql("SELECT count(*) from tableA").show() ``` ```@example gdf spark.sql("SELECT fruit, sum(v1) as s FROM tableA GROUP BY fruit ORDER BY s").show() ``` ================================================ FILE: docs/src/streaming.md ================================================ # Structured Streaming Structured Streaming is a scalable and fault-tolerant stream processing engine built on the Spark SQL engine. In this tutorial, we explore basic API of the Structured Streaming in Spark.jl. For a general introduction into the topic and more advanced examples follow the [official guide](https://spark.apache.org/docs/latest/structured-streaming-programming-guide.html) and adapt Python snippets. Let’s say you want to maintain a running word count of text data received from a data server listening on a TCP socket. We will use Netcat to send this data: ``` nc -lk 9999 ``` As usually, we start by creating a SparkSession: ```@example basic using Spark spark = SparkSession. builder. master("local"). appName("StructuredNetworkWordCount"). getOrCreate() ``` Next, let’s create a streaming DataFrame that represents text data received from a server listening on localhost:9999, and transform the DataFrame to calculate word counts. ```@example basic # Create DataFrame representing the stream of input lines from connection to localhost:9999 lines = spark. readStream. format("socket"). option("host", "localhost"). option("port", 9999). load() # Split the lines into words words = lines.select( lines.value.split(" ").explode().alias("word") ) # Generate running word count wordCounts = words.groupBy("word").count() ``` This `lines` DataFrame represents an unbounded table containing the streaming text data. This table contains one column of strings named “value”, and each line in the streaming text data becomes a row in the table. Note, that this is not currently receiving any data as we are just setting up the transformation, and have not yet started it. Next, we have used two built-in SQL functions - `split` and `explode`, to split each line into multiple rows with a word each. In addition, we use the function `alias` to name the new column as "word". Finally, we have defined the `wordCounts` DataFrame by grouping by the unique values in the Dataset and counting them. Note that this is a streaming DataFrame which represents the running word counts of the stream. We have now set up the query on the streaming data. All that is left is to actually start receiving data and computing the counts. To do this, we set it up to print the complete set of counts (specified by `outputMode("complete"))` to the console every time they are updated. And then start the streaming computation using `start()`. ```julia query = wordCounts. writeStream. outputMode("complete"). format("console"). 
start() query.awaitTermination() ``` Now type a few lines in the Netcat terminal window and you should see output similar to this: ```julia julia> query.awaitTermination() ------------------------------------------- Batch: 0 ------------------------------------------- +----+-----+ |word|count| +----+-----+ +----+-----+ ------------------------------------------- Batch: 1 ------------------------------------------- +------------+-----+ | word|count| +------------+-----+ | was| 1| | for| 1| | beginning| 1| | Julia| 1| | designed| 1| | the| 1| | high| 1| | from| 1| |performance.| 1| +------------+-----+ ------------------------------------------- Batch: 2 ------------------------------------------- +------------+-----+ | word|count| +------------+-----+ | was| 1| | for| 1| | beginning| 1| | Julia| 2| | is| 1| | designed| 1| | the| 1| | high| 1| | from| 1| | typed| 1| |performance.| 1| | dynamically| 1| +------------+-----+ ``` ================================================ FILE: examples/InstallJuliaEMR.sh ================================================ #!/bin/bash ## This is a bootstrap action for installing Julia and Spark.jl on an Amazon EMR cluster. ## It's been tested with Julia 1.6.2 and EMR 5.33 and performs the following actions: ## 1. Installs Julia 1.6.2 and Maven 3.8.1 ## 2. Configures the "hadoop" user's startup.jl to load Spark/Hadoop dependencies ## 3. Creates a shared package directory in which to install Spark.jl ## 4. Installs v0.5.1 of Spark.jl for the necessary Spark/Scala versions ## ## You can run this script manually on every node or upload it to S3 and run it as a bootstrap action. ## When creating the EMR cluster, set the "spark-default" configuration with the following JSON. ## Reference: https://docs.aws.amazon.com/emr/latest/ReleaseGuide/emr-spark-configure.html # # [ # { # "Classification": "spark-defaults", # "Properties": { # "spark.executorEnv.JULIA_HOME": "/usr/local/julia-1.6.2/bin", # "spark.executorEnv.JULIA_DEPOT_PATH": "/usr/local/share/julia/v1.6.2", # "spark.executorEnv.JULIA_VERSION": "v1.6.2" # } # } # ] export JULIA_VERSION="1.6.2" export JULIA_DL_URL="https://julialang-s3.julialang.org/bin/linux/x64/1.6/julia-1.6.2-linux-x86_64.tar.gz" # install julia curl -sL ${JULIA_DL_URL} | sudo tar -xz -C /usr/local/ JULIA_DIR=/usr/local/julia-${JULIA_VERSION} # install maven curl -s https://mirrors.sonic.net/apache/maven/maven-3/3.8.1/binaries/apache-maven-3.8.1-bin.tar.gz | sudo tar -xz -C /usr/local/ MAVEN_DIR=/usr/local/apache-maven-3.8.1 # Update the `hadoop` user's current and future path with Maven and Julia. # This allows us to download/install Spark.jl export PATH=${MAVEN_DIR}/bin:${JULIA_DIR}/bin:${PATH} echo "export PATH=${MAVEN_DIR}/bin:${JULIA_DIR}/bin:${PATH}" >> /home/hadoop/.bashrc # Create a shared package dir for the installation sudo mkdir -p /usr/local/share/julia/v${JULIA_VERSION} && \ sudo chown -R hadoop.hadoop /usr/local/share/julia/ && \ sudo chmod -R go+r /usr/local/share/julia/ # Create a config file that adds Spark environment variables # and adds the new package dir to the DEPOT_PATH. # This ensures that Spark.jl gets installed to a shared location. 
export TARGET_USER=hadoop export JULIA_CFG_DIR="/home/${TARGET_USER}/.julia/config" mkdir -p ${JULIA_CFG_DIR} && \ touch ${JULIA_CFG_DIR}/startup.jl && \ chown -R hadoop.hadoop /home/hadoop/.julia echo 'ENV["SPARK_HOME"] = "/usr/lib/spark/"' >> "${JULIA_CFG_DIR}/startup.jl" echo 'ENV["HADOOP_CONF_DIR"] = "/etc/hadoop/conf"' >> "${JULIA_CFG_DIR}/startup.jl" echo 'push!(DEPOT_PATH, "/usr/local/share/julia/v'${JULIA_VERSION}'")' >> "${JULIA_CFG_DIR}/startup.jl" # Install Spark.jl - we need to explicity define Spark/Scala versions here BUILD_SCALA_VERSION=2.11.12 \ BUILD_SPARK_VERSION=2.4.7 \ JULIA_COPY_STACKS=yes \ JULIA_DEPOT_PATH=/usr/local/share/julia/v${JULIA_VERSION} \ julia -e 'using Pkg;Pkg.add(Pkg.PackageSpec(;name="Spark", version="0.5.1"));using Spark;' ================================================ FILE: examples/InstallJuliaHDI.sh ================================================ #!/usr/bin/env bash # An example shell script that can be used on Azure HDInsight to install Julia to HDI Spark cluster # This script, or a derivative should be set as a script action when deploying an HDInsight cluster # install julia v0.6 curl -sL https://julialang-s3.julialang.org/bin/linux/x64/1.0/julia-1.0.5-linux-x86_64.tar.gz | sudo tar -xz -C /usr/local/ JULIA_HOME=/usr/local/julia-1.0.5/bin # install maven curl -s http://mirror.olnevhost.net/pub/apache/maven/binaries/apache-maven-3.2.2-bin.tar.gz | sudo tar -xz -C /usr/local/ export M2_HOME=/usr/local/apache-maven-3.2.2 export PATH=$M2_HOME/bin:$PATH # Create Directories export JULIA_DEPOT_PATH="/home/hadoop/.julia/" mkdir -p ${JULIA_DEPOT_PATH} # Set Environment variables for current session export PATH=${PATH}:${MVN_HOME}/bin:${JULIA_HOME}/bin export HOME="/root" echo "Installing Julia Packages in Julia Folder ${JULIA_DEPOT_PATH}" #Install Spark.jl $JULIA_HOME/julia -e 'using Pkg; Pkg.add("Spark");Pkg.build("Spark"); using Spark;' declare -a users=("spark" "yarn" "hadoop" "sshuser") #TODO: change accordingly SPARK_HOME=/usr/hdp/current/spark2-client echo "spark.executorEnv.JULIA_HOME ${JULIA_HOME}" >> ${SPARK_HOME}/conf/spark-defaults.conf echo "spark.executorEnv.JULIA_DEPOT_PATH ${JULIA_DEPOT_PATH}" >> ${SPARK_HOME}/conf/spark-defaults.conf echo "spark.executorEnv.JULIA_VERSION v1.0.5" >> ${SPARK_HOME}/conf/spark-defaults.conf for cusr in "${users[@]}"; do echo " Adding vars for ser ${cusr}" echo "" >> /home/${cusr}/.bashrc echo "export MVN_HOME=/usr/local/apache-maven-3.2.2" >> /home/${cusr}/.bashrc echo "export PATH=${PATH}:${MVN_HOME}/bin:${JULIA_HOME}" >> /home/${cusr}/.bashrc echo "export YARN_CONF_DIR=/etc/hadoop/conf" >> /home/${cusr}/.bashrc echo "export JULIA_HOME=${JULIA_HOME}" >> /home/${cusr}/.bashrc echo "export JULIA_DEPOT_PATH=${JULIA_DEPOT_PATH}" >> /home/${cusr}/.bashrc echo "source ${SPARK_HOME}/bin/load-spark-env.sh" >> /home/${cusr}/.bashrc # Set Package folder permissions setfacl -R -m u:${cusr}:rwx ${JULIA_DEPOT_PATH}; done ================================================ FILE: examples/SparkSubmitJulia.scala ================================================ /** * A simple scala class that can be used along with spark-submit to * submit a Julia script to be run in a spark cluster. 
E.g.: * * $ spark-submit --class org.julialang.juliaparallel.SparkSubmitJulia \ * --master yarn \ * --deploy-mode cluster \ * --driver-memory 4g \ * --executor-memory 2g \ * --executor-cores 1 \ * spark-julia_2.11-1.0.jar \ * /opt/julia/depot/helloworld.jl \ * /usr/local/julia/bin/julia \ * /opt/julia/depot * * To compile, use `src/main/scala/SparkSubmitJulia.scala` with a build.sbt like: * --------------------- * name := "Spark Submit Julia" * version := "1.0" * scalaVersion := "2.11.8" * libraryDependencies += "org.apache.spark" % "spark-sql_2.11" % "2.4.4" * --------------------- */ package org.julialang.juliaparallel import scala.sys.process._ import org.apache.spark.sql.SparkSession object SparkSubmitJulia { def main(args: Array[String]): Unit = { val spark = SparkSession .builder .appName("Spark Submit Julia") .getOrCreate() val script = args(0) // e.g.: "/opt/julia/depot/helloworld.jl" val juliapath = args(1) // e.g.: "/usr/local/julia/bin/julia" val juliadepotpath = args(2) // e.g.: "/opt/julia/depot" val exitcode = Process(Seq(juliapath, script), None, "JULIA_DEPOT_PATH" -> juliadepotpath).! println(s"Completed with exitcode $exitcode") spark.stop() } } ================================================ FILE: jvm/sparkjl/dependency-reduced-pom.xml ================================================ 4.0.0 sparkjl sparkjl sparkjl 0.2 net.alchim31.maven scala-maven-plugin 4.6.1 maven-compiler-plugin 3.8.1 1.8 1.8 net.alchim31.maven scala-maven-plugin scala-compile-first process-resources add-source compile scala-test-compile process-test-resources testCompile maven-compiler-plugin compile compile 1.8 1.8 maven-shade-plugin 3.3.0 package shade META-INF/*.SF META-INF/*.DSA META-INF/*.RSA classworlds:classworlds junit:junit jmock:* *:xml-apis org.apache.maven:lib:tests log4j:log4j:jar: 2.13.6 [3.2.0,3.2.1] 1.11 64m UTF-8 2.13 512m UTF-8 ================================================ FILE: jvm/sparkjl/old_src/InputIterator.scala ================================================ package org.apache.spark.api.julia import java.io.{BufferedInputStream, DataInputStream, EOFException} import java.net.Socket import org.apache.spark.internal.Logging import org.apache.commons.compress.utils.Charsets import org.apache.spark._ /** * Iterator that connects to a Julia process and reads data back to JVM. 
* */ class InputIterator[T](context: TaskContext, worker: Socket, outputThread: OutputThread) extends Iterator[T] with Logging { val BUFFER_SIZE = 65536 val env = SparkEnv.get val stream = new DataInputStream(new BufferedInputStream(worker.getInputStream, BUFFER_SIZE)) override def next(): T = { val obj = _nextObj if (hasNext) { _nextObj = read() } obj } private def read(): T = { if (outputThread.exception.isDefined) { throw outputThread.exception.get } try { JuliaRDD.readValueFromStream(stream).asInstanceOf[T] } catch { case e: Exception if context.isInterrupted => logDebug("Exception thrown after task interruption", e) throw new TaskKilledException case e: Exception if env.isStopped => logDebug("Exception thrown after context is stopped", e) null.asInstanceOf[T] // exit silently case e: Exception if outputThread.exception.isDefined => logError("Julia worker exited unexpectedly (crashed)", e) logError("This may have been caused by a prior exception:", outputThread.exception.get) throw outputThread.exception.get case eof: EOFException => throw new SparkException("Julia worker exited unexpectedly (crashed)", eof) } } var _nextObj = read() override def hasNext: Boolean = _nextObj != null } ================================================ FILE: jvm/sparkjl/old_src/JuliaRDD.scala ================================================ package org.apache.spark.api.julia import java.io._ import java.net._ import sys.process.Process import java.nio.file.Paths import org.apache.commons.compress.utils.Charsets import org.apache.spark._ import org.apache.spark.api.java.{JavaPairRDD, JavaRDD, JavaSparkContext} import org.apache.spark.internal.Logging import org.apache.spark.rdd.RDD import scala.collection.JavaConversions._ import scala.language.existentials import scala.reflect.ClassTag class AbstractJuliaRDD[T:ClassTag]( @transient parent: RDD[_], command: Array[Byte] ) extends RDD[T](parent) { val preservePartitioning = true val reuseWorker = true override def getPartitions: Array[Partition] = firstParent.partitions // Note: needs to override in later versions of Spark // override def getNumPartitions: Int = firstParent.partitions.length override val partitioner: Option[Partitioner] = { if (preservePartitioning) firstParent.partitioner else None } override def compute(split: Partition, context: TaskContext): Iterator[T] = { val worker: Socket = JuliaRDD.createWorker() // Start a thread to feed the process input from our parent's iterator val outputThread = new OutputThread(context, firstParent.iterator(split, context), worker, command, split) outputThread.start() // Return an iterator that read lines from the process's stdout val resultIterator = new InputIterator[T](context, worker, outputThread) new InterruptibleIterator(context, resultIterator) } } class JuliaRDD(@transient parent: RDD[_],command: Array[Byte]) extends AbstractJuliaRDD[Any](parent, command) { def asJavaRDD(): JavaRDD[Any] = { JavaRDD.fromRDD(this) } } private object SpecialLengths { val END_OF_DATA_SECTION = -1 val JULIA_EXCEPTION_THROWN = -2 val TIMING_DATA = -3 val END_OF_STREAM = -4 val NULL = -5 val PAIR_TUPLE = -6 val ARRAY_VALUE = -7 val ARRAY_END = -8 val INTEGER = -9 val STRING_START = -100 } object JuliaRDD extends Logging { def fromRDD[T](rdd: RDD[T], command: Array[Byte]): JuliaRDD = new JuliaRDD(rdd, command) def createWorker(): Socket = { var serverSocket: ServerSocket = null try { serverSocket = new ServerSocket(0, 1, InetAddress.getByAddress(Array(127, 0, 0, 1).map(_.toByte))) // Create and start the worker val juliaHome 
= sys.env.get("JULIA_HOME").getOrElse("") val juliaVersion = sys.env.get("JULIA_VERSION").getOrElse("v0.7") val juliaCommand = Paths.get(juliaHome, "julia").toString() val juliaPkgDir = sys.env.get("JULIA_PKGDIR") match { case Some(i) => Paths.get(i, juliaVersion, "Spark").toString() case None => Process(juliaCommand + " -e println(dirname(dirname(Base.find_package(\"Spark\"))))").!!.trim } val pb = new ProcessBuilder(juliaCommand, Paths.get(juliaPkgDir, "src", "worker_runner.jl").toString()) pb.directory(new File(SparkFiles.getRootDirectory())) // val workerEnv = pb.environment() // workerEnv.putAll(envVars) val worker = pb.start() // Redirect worker stdout and stderr StreamUtils.redirectStreamsToStderr(worker.getInputStream, worker.getErrorStream) // Tell the worker our port val out = new OutputStreamWriter(worker.getOutputStream) out.write(serverSocket.getLocalPort + "\n") out.flush() // Wait for it to connect to our socket serverSocket.setSoTimeout(120000) try { val socket = serverSocket.accept() // workers.put(socket, worker) return socket } catch { case e: Exception => throw new SparkException("Julia worker did not connect back in time", e) } } finally { if (serverSocket != null) { serverSocket.close() } } null } def writeValueToStream[T](obj: Any, dataOut: DataOutputStream) { obj match { case arr: Array[Byte] => dataOut.writeInt(arr.length) dataOut.write(arr) case tup: Tuple2[Any, Any] => dataOut.writeInt(SpecialLengths.PAIR_TUPLE) writeValueToStream(tup._1, dataOut) writeValueToStream(tup._2, dataOut) case str: String => val arr = str.getBytes(Charsets.UTF_8) dataOut.writeInt(-arr.length + SpecialLengths.STRING_START) dataOut.write(arr) case jac: java.util.AbstractCollection[_] => writeValueToStream(jac.iterator, dataOut) case jit: java.util.Iterator[_] => while (jit.hasNext) { dataOut.writeInt(SpecialLengths.ARRAY_VALUE) writeValueToStream(jit.next(), dataOut) } dataOut.writeInt(SpecialLengths.ARRAY_END) case ita: Iterable[_] => writeValueToStream(ita.iterator, dataOut) case it: Iterator[_] => while (it.hasNext) { dataOut.writeInt(SpecialLengths.ARRAY_VALUE) writeValueToStream(it.next(), dataOut) } dataOut.writeInt(SpecialLengths.ARRAY_END) case x: Int => dataOut.writeInt(SpecialLengths.INTEGER) dataOut.writeLong(x) case x: java.lang.Long => dataOut.writeInt(SpecialLengths.INTEGER) dataOut.writeLong(x) case x: java.lang.Integer => dataOut.writeInt(SpecialLengths.INTEGER) dataOut.writeLong(x.longValue) case other => throw new SparkException("Unexpected element type " + other.getClass) } } def readValueFromStream(stream: DataInputStream) : Any = { var typeLength = stream.readInt() typeLength match { case length if length > 0 => val obj = new Array[Byte](length) stream.readFully(obj) obj case 0 => Array.empty[Byte] case SpecialLengths.PAIR_TUPLE => (readValueFromStream(stream), readValueFromStream(stream)) case SpecialLengths.JULIA_EXCEPTION_THROWN => // Signals that an exception has been thrown in julia val exLength = stream.readInt() val strlength = -exLength + SpecialLengths.STRING_START val obj = new Array[Byte](strlength) stream.readFully(obj) val str = new String(obj, Charsets.UTF_8) throw new Exception(str) case SpecialLengths.ARRAY_VALUE => val ab = new collection.mutable.ArrayBuffer[Any]() while(typeLength == SpecialLengths.ARRAY_VALUE) { ab += readValueFromStream(stream) typeLength = stream.readInt() } ab.toIterator case SpecialLengths.ARRAY_END => new Array[Any](0) case SpecialLengths.INTEGER => stream.readLong() case SpecialLengths.STRING_START => "" case length if length 
< SpecialLengths.STRING_START => val strlength = -length + SpecialLengths.STRING_START val obj = new Array[Byte](strlength) stream.readFully(obj) new String(obj, Charsets.UTF_8) case SpecialLengths.END_OF_DATA_SECTION => if (stream.readInt() == SpecialLengths.END_OF_STREAM) { null } else { throw new RuntimeException("Protocol error") } } } def readRDDFromFile(sc: JavaSparkContext, filename: String, parallelism: Int): JavaRDD[Any] = { val file = new DataInputStream(new FileInputStream(filename)) try { val objs = new collection.mutable.ArrayBuffer[Any] try { while (true) { objs.append(readValueFromStream(file)) } } catch { case eof: EOFException => // No-op } JavaRDD.fromRDD(sc.sc.parallelize(objs, parallelism)) } finally { file.close() } } def cartesianSS(rdd1: JavaRDD[Any], rdd2: JavaRDD[Any]): JavaPairRDD[Any, Any] = { rdd1.cartesian(rdd2) } def collectToJulia(rdd: JavaRDD[Any]): Array[Byte] = { writeToByteArray[java.util.List[Any]](rdd.collect()) } def collectToJuliaItr(rdd: JavaRDD[Any]): java.util.List[Any] = { return rdd.collect() } def writeToByteArray[T](obj: Any): Array[Byte] = { val byteArrayOut = new ByteArrayOutputStream() val dataStream = new DataOutputStream(byteArrayOut) writeValueToStream(obj, dataStream) dataStream.flush() byteArrayOut.toByteArray() } } class JuliaPairRDD(@transient parent: RDD[_],command: Array[Byte]) extends AbstractJuliaRDD[(Any, Any)](parent, command) { def asJavaPairRDD(): JavaPairRDD[Any, Any] = { JavaPairRDD.fromRDD(this) } } object JuliaPairRDD extends Logging { def fromRDD[T](rdd: RDD[T], command: Array[Byte]): JuliaPairRDD = new JuliaPairRDD(rdd, command) def collectToJulia(rdd: JavaPairRDD[Any, Any]): Array[Byte] = { JuliaRDD.writeToByteArray[java.util.List[(Any, Any)]](rdd.collect()) } def collectToJuliaItr(rdd: JavaPairRDD[Any, Any]): java.util.List[(Any, Any)] = { return rdd.collect() } } ================================================ FILE: jvm/sparkjl/old_src/JuliaRunner.scala ================================================ package org.apache.spark.api.julia import scala.collection.JavaConversions._ /** * Class for execution of Julia scripts on a cluster. * WARNING: this class isn't used currently, will be utilized later */ object JuliaRunner { def main(args: Array[String]): Unit = { val juliaScript = args(0) val scriptArgs = args.slice(1, args.length) val pb = new ProcessBuilder(Seq("julia", juliaScript) ++ scriptArgs) val process = pb.start() StreamUtils.redirectStreamsToStderr(process.getInputStream, process.getErrorStream) val errorCode = process.waitFor() if (errorCode != 0) { throw new RuntimeException("Julia script exited with an error") } } } ================================================ FILE: jvm/sparkjl/old_src/OutputThread.scala ================================================ package org.apache.spark.api.julia import java.io.{DataOutputStream, BufferedOutputStream} import java.net.Socket import org.apache.spark.util.Utils import org.apache.spark.{TaskContext, Partition, SparkEnv} /** * The thread responsible for writing the data from the JuliaRDD's parent iterator to the * Julia process. */ class OutputThread(context: TaskContext, it: Iterator[Any], worker: Socket, command: Array[Byte], split: Partition) extends Thread(s"stdout writer for julia") { val BUFFER_SIZE = 65536 val env = SparkEnv.get @volatile private var _exception: Exception = null /** Contains the exception thrown while writing the parent iterator to the Julia process. 
*/ def exception: Option[Exception] = Option(_exception) /** Terminates the writer thread, ignoring any exceptions that may occur due to cleanup. */ def shutdownOnTaskCompletion() { assert(context.isCompleted) this.interrupt() } override def run(): Unit = Utils.logUncaughtExceptions { try { val stream = new BufferedOutputStream(worker.getOutputStream, BUFFER_SIZE) val dataOut = new DataOutputStream(stream) // partition index dataOut.writeInt(split.index) dataOut.flush() // serialized command: dataOut.writeInt(command.length) dataOut.write(command) dataOut.flush() // data values writeIteratorToStream(it, dataOut) dataOut.writeInt(SpecialLengths.END_OF_DATA_SECTION) dataOut.writeInt(SpecialLengths.END_OF_STREAM) dataOut.flush() } catch { case e: Exception if context.isCompleted || context.isInterrupted => // FIXME: logDebug("Exception thrown after task completion (likely due to cleanup)", e) println("Exception thrown after task completion (likely due to cleanup)", e) if (!worker.isClosed) { Utils.tryLog(worker.shutdownOutput()) } case e: Exception => // We must avoid throwing exceptions here, because the thread uncaught exception handler // will kill the whole executor (see org.apache.spark.executor.Executor). _exception = e if (!worker.isClosed) { Utils.tryLog(worker.shutdownOutput()) } } // } finally { // // Release memory used by this thread for shuffles // // env.shuffleMemoryManager.releaseMemoryForThisThread() // env.shuffleMemoryManager.releaseMemoryForThisTask() // // Release memory used by this thread for unrolling blocks // // env.blockManager.memoryStore.releaseUnrollMemoryForThisThread() // env.blockManager.memoryStore.releaseUnrollMemoryForThisTask() // } } def writeIteratorToStream[T](iter: Iterator[T], dataOut: DataOutputStream) { def write(obj: Any): Unit = { JuliaRDD.writeValueToStream(obj, dataOut) } iter.foreach(write) } } ================================================ FILE: jvm/sparkjl/old_src/RDDUtils.scala ================================================ package org.apache.spark.api.julia import org.apache.spark.internal.Logging import org.apache.spark.api.java.{JavaRDD, JavaPairRDD} object RDDUtils extends Logging { /** * Get number of partitions in the RDD */ def getNumPartitions(jrdd: JavaRDD[Any]): Int = jrdd.rdd.partitions.length def getNumPartitions(jrdd: JavaPairRDD[Any,Any]): Int = jrdd.rdd.partitions.length } ================================================ FILE: jvm/sparkjl/old_src/StreamUtils.scala ================================================ package org.apache.spark.api.julia import java.io.InputStream import org.apache.spark.internal.Logging import org.apache.spark.util.RedirectThread object StreamUtils extends Logging { /** * Redirect the given streams to our stderr in separate threads. 
*/ def redirectStreamsToStderr(stdout: InputStream, stderr: InputStream) { try { new RedirectThread(stdout, System.err, "stdout reader for julia").start() new RedirectThread(stderr, System.err, "stderr reader for julia").start() } catch { case e: Exception => logError("Exception in redirecting streams", e) } } } ================================================ FILE: jvm/sparkjl/pom.xml ================================================ sparkjl sparkjl 4.0.0 sparkjl jar 0.2 UTF-8 UTF-8 1.11 2.13 2.13.6 [3.2.0,3.2.1] 64m 512m org.scala-lang scala-library ${scala.binary.version} org.apache.spark spark-core_${scala.version} ${spark.version} org.apache.hadoop hadoop-client org.apache.spark spark-yarn_${scala.version} ${spark.version} org.apache.spark spark-sql_${scala.version} ${spark.version} org.mdkt.compiler InMemoryJavaCompiler 1.3.0 net.alchim31.maven scala-maven-plugin 4.6.1 org.apache.maven.plugins maven-compiler-plugin 3.8.1 1.8 1.8 net.alchim31.maven scala-maven-plugin scala-compile-first process-resources add-source compile scala-test-compile process-test-resources testCompile org.apache.maven.plugins maven-compiler-plugin 1.8 1.8 compile compile org.apache.maven.plugins maven-shade-plugin 2.0 true assembly *:* *:* META-INF/*.SF META-INF/*.DSA META-INF/*.RSA META-INF/services/org.apache.hadoop.fs.FileSystem reference.conf package shade ================================================ FILE: src/Spark.jl ================================================ module Spark include("core.jl") end ================================================ FILE: src/chainable.jl ================================================ """ DotChainer{O, Fn} See `@chainable` for details. """ struct DotChainer{O, Fn} obj::O fn::Fn end # DotChainer(obj, fn) = DotChainer{typeof(obj), typeof(fn)}(obj, fn) (c::DotChainer)(args...) = c.fn(c.obj, args...) """ @chainable T Adds dot chaining syntax to the type, i.e. automatically translate: foo.bar(a) into bar(foo, a) For single-argument functions also support implicit calls, e.g: foo.bar.baz(a, b) is treated the same as: foo.bar().baz(a, b) Note that `@chainable` works by overloading `Base.getproperty()`, making it impossible to customize it for `T`. To have more control, one may use the underlying wrapper type - `DotCaller`. 
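For example, the `SparkSession` builder from the SQL interface can be configured
using either notation:

    spark = SparkSession.builder.appName("Main").master("local").getOrCreate()

    # the dot-chained call above is equivalent to
    spark = getOrCreate(master(appName(SparkSession.builder, "Main"), "local"))
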
""" macro chainable(T) return quote function Base.getproperty(obj::$(esc(T)), prop::Symbol) if hasfield(typeof(obj), prop) return getfield(obj, prop) elseif isdefined(@__MODULE__, prop) fn = getfield(@__MODULE__, prop) return DotChainer(obj, fn) else error("type $(typeof(obj)) has no field $prop") end end end end function Base.getproperty(dc::DotChainer, prop::Symbol) if hasfield(typeof(dc), prop) return getfield(dc, prop) else # implicitely call function without arguments # and propagate getproperty to the returned object return getproperty(dc(), prop) end end ================================================ FILE: src/column.jl ================================================ ############################################################################### # Column # ############################################################################### function Column(name::String) jcol = jcall(JSQLFunctions, "col", JColumn, (JString,), name) return Column(jcol) end @chainable Column function Base.show(io::IO, col::Column) name = jcall(col.jcol, "toString", JString, ()) print(io, "col(\"$name\")") end # binary with JObject for (func, name) in [(:+, "plus"), (:-, "minus"), (:*, "multiply"), (:/, "divide")] @eval function Base.$func(col::Column, obj::T) where T jres = jcall(col.jcol, $name, JColumn, (JObject,), obj) return Column(jres) end end alias(col::Column, name::String) = Column(jcall(col.jcol, "alias", JColumn, (JString,), name)) asc(col::Column) = Column(jcall(col.jcol, "asc", JColumn, ())) asc_nulls_first(col::Column) = Column(jcall(col.jcol, "asc_nulls_first", JColumn, ())) asc_nulls_last(col::Column) = Column(jcall(col.jcol, "asc_nulls_last", JColumn, ())) between(col::Column, low, up) = Column(jcall(col.jcol, "between", JColumn, (JObject, JObject), low, up)) bitwiseAND(col::Column, other) = Column(jcall(col.jcol, "bitwiseAND", JColumn, (JObject,), other)) Base.:&(col::Column, other) = bitwiseAND(col, other) bitwiseOR(col::Column, other) = Column(jcall(col.jcol, "bitwiseOR", JColumn, (JObject,), other)) Base.:|(col::Column, other) = bitwiseOR(col, other) bitwiseXOR(col::Column, other) = Column(jcall(col.jcol, "bitwiseXOR", JColumn, (JObject,), other)) Base.:⊻(col::Column, other) = bitwiseXOR(col, other) Base.contains(col::Column, other) = Column(jcall(col.jcol, "contains", JColumn, (JObject,), other)) desc(col::Column) = Column(jcall(col.jcol, "desc", JColumn, ())) desc_nulls_first(col::Column) = Column(jcall(col.jcol, "desc_nulls_first", JColumn, ())) desc_nulls_last(col::Column) = Column(jcall(col.jcol, "desc_nulls_last", JColumn, ())) # dropFields should go here, but it's not in listmethods(col.jcol) ¯\_(ツ)_/¯ Base.endswith(col::Column, other) = Column(jcall2(col.jcol, "endsWith", JColumn, (JObject,), other)) Base.endswith(col::Column, other::Column) = Column(jcall(col.jcol, "endsWith", JColumn, (JColumn,), other.jcol)) eqNullSafe(col::Column, other) = Column(jcall(col.jcol, "eqNullSafe", JColumn, (JObject,), other)) Base.:(==)(col::Column, other) = Column(jcall(col.jcol, "equalTo", JColumn, (JObject,), other)) Base.:(!=)(col::Column, other) = Column(jcall(col.jcol, "notEqual", JColumn, (JObject,), other)) explain(col::Column, extended=false) = jcall(col.jcol, "explain", Nothing, (jboolean,), extended) isNotNull(col::Column) = Column(jcall(col.jcol, "isNotNull", JColumn, ())) isNull(col::Column) = Column(jcall(col.jcol, "isNull", JColumn, ())) like(col::Column, s::String) = Column(jcall(col.jcol, "like", JColumn, (JString,), s)) otherwise(col::Column, other) = Column(jcall(col.jcol, 
"otherwise", JColumn, (JObject,), other)) over(col::Column) = Column(jcall(col.jcol, "over", JColumn, ())) rlike(col::Column, s::String) = Column(jcall(col.jcol, "rlike", JColumn, (JString,), s)) Base.startswith(col::Column, other) = Column(jcall2(col.jcol, "startsWith", JColumn, (JObject,), other)) Base.startswith(col::Column, other::Column) = Column(jcall(col.jcol, "startsWith", JColumn, (JColumn,), other.jcol)) substr(col::Column, start::Column, len::Column) = Column(jcall(col.jcol, "substr", JColumn, (JColumn, JColumn), start.jcol, len.jcol)) substr(col::Column, start::Integer, len::Integer) = Column(jcall(col.jcol, "substr", JColumn, (jint, jint), start, len)) when(col::Column, condition::Column, value) = Column(jcall(col.jcol, "when", JColumn, (JColumn, JObject), condition.jcol, value)) ## JSQLFunctions upper(col::Column) = Column(jcall(JSQLFunctions, "upper", JColumn, (JColumn,), col.jcol)) Base.uppercase(col::Column) = upper(col) lower(col::Column) = Column(jcall(JSQLFunctions, "lower", JColumn, (JColumn,), col.jcol)) Base.lowercase(col::Column) = lower(col) for func in (:min, :max, :count, :sum, :mean) @eval function $func(col::Column) jcol = jcall(JSQLFunctions, string($func), JColumn, (JColumn,), col.jcol) return Column(jcol) end end Base.minimum(col::Column) = min(col) Base.maximum(col::Column) = max(col) avg(col::Column) = mean(col) explode(col::Column) = Column(jcall(JSQLFunctions, "explode", JColumn, (JColumn,), col.jcol)) Base.split(col::Column, sep::AbstractString) = Column(jcall(JSQLFunctions, "split", JColumn, (JColumn, JString), col.jcol, sep)) function window(col::Column, w_dur::String, slide_dur::String, start_time::String) return Column(jcall(JSQLFunctions, "window", JColumn, (JColumn, JString, JString, JString), col.jcol, w_dur, slide_dur, start_time)) end function window(col::Column, w_dur::String, slide_dur::String) return Column(jcall(JSQLFunctions, "window", JColumn, (JColumn, JString, JString), col.jcol, w_dur, slide_dur)) end function window(col::Column, w_dur::String) return Column(jcall(JSQLFunctions, "window", JColumn, (JColumn, JString), col.jcol, w_dur)) end ================================================ FILE: src/compiler.jl ================================================ using JavaCall import JavaCall: assertroottask_or_goodenv, assertloaded using Umlaut const JInMemoryJavaCompiler = @jimport org.mdkt.compiler.InMemoryJavaCompiler # const JDynamicJavaCompiler = @jimport org.apache.spark.api.julia.DynamicJavaCompiler # const JFile = @jimport java.io.File # const JToolProvider = @jimport javax.tools.ToolProvider # const JJavaCompiler = @jimport javax.tools.JavaCompiler # const JInputStream = @jimport java.io.InputStream # const JOutputStream = @jimport java.io.OutputStream # const JClassLoader = @jimport java.lang.ClassLoader # const JURLClassLoader = @jimport java.net.URLClassLoader # const JURI = @jimport java.net.URI # const JURL = @jimport java.net.URL const JUDF1 = @jimport org.apache.spark.sql.api.java.UDF1 ############################################################################### # Compiler # ############################################################################### function create_class(name::String, src::String) jcompiler = jcall(JInMemoryJavaCompiler, "newInstance", JInMemoryJavaCompiler, ()) return jcall(jcompiler, "compile", JClass, (JString, JString), name, src) end function create_instance(name::String, src::String) jclass = create_class(name, src) return jcall(jclass, "newInstance", JObject, ()) end function 
create_instance(src::String) pkg_name_match = match(r"package ([a-zA-z0-9_\.\$]+);", src) @assert !isnothing(pkg_name_match) "Cannot detect package name in the source:\n\n$src" pkg_name = pkg_name_match.captures[1] class_name_match = match(r"class ([a-zA-z0-9_\$]+)", src) @assert !isnothing(class_name_match) "Cannot detect class name in the source:\n\n$src" class_name = class_name_match.captures[1] return create_instance("$pkg_name.$class_name", src) end ############################################################################### # jcall2 # ############################################################################### function jcall_reflect(jobj::JavaObject, name::String, rettype, argtypes, args...) assertroottask_or_goodenv() && assertloaded() jclass = getclass(jobj) jargs = [a for a in convert.(argtypes, args)] # convert to Vector meth = jcall(jclass, "getMethod", JMethod, (JString, Vector{JClass}), name, getclass.(jargs)) ret = meth(jobj, jargs...) return convert(rettype, ret) end # jcall() fails to call methods of generated classes, jcall2() is a more robust version of it # see https://github.com/JuliaInterop/JavaCall.jl/issues/166 for the details function jcall2(jobj::JavaObject, name::String, rettype, argtypes, args...) try return jcall(jobj, name, rettype, argtypes, args...) catch return jcall_reflect(jobj, name, rettype, argtypes, args...) end end ############################################################################### # JavaExpr # ############################################################################### javastring(::Type{JavaObject{name}}) where name = string(name) javastring(::Nothing) = "" javatype(tape::Tape, v::Variable) = julia2java(typeof(tape[v].val)) javaname(v::Variable) = string(Umlaut.make_name(v.id)) javaname(op::AbstractArray) = javaname(V(op)) javaname(x) = x # literals type_param_string(typeparams::Vector{String}) = isempty(typeparams) ? "" : "<$(join(typeparams, ", "))>" type_param_string(typeparams::Vector) = isempty(typeparams) ? "" : "<$(join(map(javastring, typeparams), ", "))>" abstract type JavaExpr end Base.show(io::IO, ex::JavaExpr) = print(io, javastring(ex)) mutable struct JavaTypeExpr <: JavaExpr class::Type{<:JavaObject} typeparams::Vector # String or Type{<:JavaObject} end JavaTypeExpr(JT::Type{<:JavaObject}) = JavaTypeExpr(JT, []) Base.convert(::Type{JavaTypeExpr}, JT::Type{<:JavaObject}) = JavaTypeExpr(JT) javastring(ex::JavaTypeExpr) = javastring(ex.class) * type_param_string(ex.typeparams) mutable struct JavaCallExpr <: JavaExpr rettype::JavaTypeExpr ret::String this::Union{String, Any} # name or constant method::String args::Vector # names or constants end function javastring(ex::JavaCallExpr) R = javastring(ex.rettype) if !isnothing(match(r"^[\*\/+-]+$", ex.method)) # binary operator return "$R $(ex.ret) = $(ex.this) $(ex.method) $(ex.args[1]);" else return "$R $(ex.ret) = $(ex.this).$(ex.method)($(join(ex.args, ", ")));" end end struct JavaReturnExpr <: JavaExpr ret::String end javastring(ex::JavaReturnExpr) = "return $(ex.ret);" mutable struct JavaMethodExpr <: JavaExpr annotations::Vector{String} rettype::JavaTypeExpr name::String params::Vector{String} paramtypes::Vector{JavaTypeExpr} body::Vector end function javastring(ex::JavaMethodExpr) paramlist = join(["$(javastring(t)) $a" for (a, t) in zip(ex.params, ex.paramtypes)], ", ") result = isempty(ex.annotations) ? 
"" : "\t" * join(ex.annotations, "\n") * "\n" result *= "public $(javastring(ex.rettype)) $(ex.name)($paramlist) {\n" for subex in ex.body result *= "\t$(javastring(subex))\n" end result *= "}" return result end mutable struct JavaClassExpr <: JavaExpr name::String typeparams::Vector{String} extends::Union{JavaTypeExpr, Nothing} implements::Union{JavaTypeExpr, Nothing} methods::Vector{<:JavaMethodExpr} end function javastring(ex::JavaClassExpr) sep = findlast(".", ex.name) pkg_name, class_name = isnothing(sep) ? ("", ex.name) : (ex.name[1:sep.start-1], ex.name[sep.start+1:end]) pkg_str = isempty(pkg_name) ? "" : "package $pkg_name;" extends_str = isnothing(ex.extends) ? "" : "extends $(javastring(ex.extends))" implements_str = isnothing(ex.implements) ? "" : "implements $(javastring(ex.implements))" methods_str = join(map(javastring, ex.methods), "\n\n") methods_str = replace(methods_str, "\n" => "\n\t") return """ $pkg_str public class $class_name $extends_str $implements_str { $methods_str } """ end ############################################################################### # Tape => JavaExpr # ############################################################################### struct J2JContext end function Umlaut.isprimitive(::J2JContext, f, args...) Umlaut.isprimitive(Umlaut.BaseCtx(), f, args...) && return true modl = parentmodule(typeof(f)) modl in (Spark, Base.Unicode) && return true return false end javamethod(::typeof(+)) = "+" javamethod(::typeof(*)) = "*" javamethod(::typeof(lowercase)) = "toLowerCase" function JavaCallExpr(tape::Tape, op::Call) ret = javaname(V(op)) rettype = javatype(tape, V(op)) this, args... = map(javaname, op.args) method = javamethod(op.fn) return JavaCallExpr(rettype, ret, this, method, args) end function JavaClassExpr(tape::Tape; method_name::String="(unspecified)") fn_name = string(tape[V(1)].val) cls = fn_name * "_" * string(gensym())[3:end] cls = replace(cls, "#" => "_") inp = inputs(tape)[2:end] params = [javaname(v) for v in inp] paramtypes = [javatype(tape, v) for v in inp] ret = javaname(tape.result) rettype = javatype(tape, tape.result) body = JavaExpr[JavaCallExpr(tape, op) for op in tape if !isa(op, Umlaut.Input)] push!(body, JavaReturnExpr(ret)) meth_expr = JavaMethodExpr([], rettype, method_name, params, paramtypes, body) return JavaClassExpr(cls, [], nothing, nothing, [meth_expr]) end ############################################################################### # UDF # ############################################################################### struct UDF src::String judf::JavaObject end Base.show(io::IO, udf::UDF) = print(io, "UDF from:\n\n" * udf.src) function udf(f::Function, args...) val, tape = trace(f, args...; ctx=J2JContext()) class_expr = JavaClassExpr(tape) class_expr.name = "julia2java." * class_expr.name UT = JavaTypeExpr( JavaCall.jimport("org.apache.spark.sql.api.java.UDF$(length(args))"), [javastring(julia2java(typeof(x))) for x in [args...; val]] ) class_expr.implements = UT meth_expr = class_expr.methods[1] meth_expr.name = "call" push!(meth_expr.annotations, "@Override") src = javastring(class_expr) judf = create_instance(src) return UDF(src, judf) end ================================================ FILE: src/convert.jl ================================================ ############################################################################### # Conversions # ############################################################################### # Note: both - java.sql.Timestamp and Julia's DateTime don't have timezone. 
# But when printing, java.sql.Timestamp will assume UTC and convert to your # local time. To avoid confusion e.g. in REPL, try use fixed date in UTC # or now(Dates.UTC) Base.convert(::Type{JTimestamp}, x::DateTime) = JTimestamp((jlong,), floor(Int, datetime2unix(x)) * 1000) Base.convert(::Type{DateTime}, x::JTimestamp) = unix2datetime(jcall(x, "getTime", jlong, ()) / 1000) Base.convert(::Type{JDate}, x::Date) = JDate((jlong,), floor(Int, datetime2unix(DateTime(x))) * 1000) Base.convert(::Type{Date}, x::JDate) = Date(unix2datetime(jcall(x, "getTime", jlong, ()) / 1000)) Base.convert(::Type{JObject}, x::Integer) = convert(JObject, convert(JLong, x)) Base.convert(::Type{JObject}, x::Real) = convert(JObject, convert(JDouble, x)) Base.convert(::Type{JObject}, x::DateTime) = convert(JObject, convert(JTimestamp, x)) Base.convert(::Type{JObject}, x::Date) = convert(JObject, convert(JDate, x)) Base.convert(::Type{JObject}, x::Column) = convert(JObject, x.jcol) Base.convert(::Type{Row}, obj::JObject) = Row(convert(JRow, obj)) Base.convert(::Type{String}, obj::JString) = unsafe_string(obj) Base.convert(::Type{Integer}, obj::JLong) = jcall(obj, "longValue", jlong, ()) julia2java(::Type{String}) = JString julia2java(::Type{Int64}) = JLong julia2java(::Type{Int32}) = JInt julia2java(::Type{Float64}) = JDouble julia2java(::Type{Float32}) = JFloat julia2java(::Type{Bool}) = JBoolean julia2java(::Type{Any}) = JObject java2julia(::Type{JString}) = String java2julia(::Type{JLong}) = Int64 java2julia(::Type{jlong}) = Int64 java2julia(::Type{JInteger}) = Int32 java2julia(::Type{jint}) = Int32 java2julia(::Type{JDouble}) = Float64 java2julia(::Type{jdouble}) = Float64 java2julia(::Type{JFloat}) = Float32 java2julia(::Type{jfloat}) = Float32 java2julia(::Type{JBoolean}) = Bool java2julia(::Type{jboolean}) = Bool java2julia(::Type{JTimestamp}) = DateTime java2julia(::Type{JDate}) = Date java2julia(::Type{JObject}) = Any julia2ddl(::Type{String}) = "string" julia2ddl(::Type{Int64}) = "long" julia2ddl(::Type{Int32}) = "int" julia2ddl(::Type{Float64}) = "double" julia2ddl(::Type{Float32}) = "float" julia2ddl(::Type{Bool}) = "boolean" julia2ddl(::Type{Dates.Date}) = "date" julia2ddl(::Type{Dates.DateTime}) = "timestamp" function JArray(x::Vector{T}) where T JT = T <: JavaObject ? T : julia2java(T) x = convert(Vector{JT}, x) sz = length(x) init_val = sz == 0 ? 
C_NULL : Ptr(x[1]) arrayptr = JavaCall.JNI.NewObjectArray(sz, Ptr(JavaCall.metaclass(JT)), init_val) arrayptr === C_NULL && geterror() for i=2:sz JavaCall.JNI.SetObjectArrayElement(arrayptr, i-1, Ptr(x[i])) end return JavaObject{typeof(x)}(arrayptr) end function Base.convert(::Type{JSeq}, x::Vector) jarr = JArray(x) jobj = convert(JObject, jarr) jarrseq = jcall(JArraySeq, "make", JArraySeq, (JObject,), jobj) return jcall(jarrseq, "toSeq", JSeq, ()) # jwa = jcall(JWrappedArray, "make", JWrappedArray, (JObject,), jobj) # jwa = jcall(JArraySeq, "make", JArraySeq, (JObject,), jobj) # return jcall(jwa, "toSeq", JSeq, ()) end function Base.convert(::Type{JMap}, d::Dict) jmap = JHashMap(()) for (k, v) in d jk, jv = convert(JObject, k), convert(JObject, v) jcall(jmap, "put", JObject, (JObject, JObject), jk, jv) end return jmap end ================================================ FILE: src/core.jl ================================================ using JavaCall using Umlaut import Umlaut.V import Statistics using Dates # using TableTraits # using IteratorInterfaceExtensions export SparkSession, DataFrame, GroupedData, Column, Row export StructType, StructField, DataType export Window, WindowSpec include("chainable.jl") include("init.jl") include("compiler.jl") include("defs.jl") include("convert.jl") include("session.jl") include("dataframe.jl") include("column.jl") include("row.jl") include("struct.jl") include("window.jl") include("io.jl") include("streaming.jl") function __init__() init() end # pseudo-modules for some specific functions not exported by default module Compiler using Reexport @reexport import Spark: udf, jcall2, create_instance, create_class end # module SQL # using Reexport # @reexport import Spark: SparkSession, DataFrame, GroupedData, Column, Row # @reexport import Spark: StructType, StructField, DataType # @reexport import Spark: Window, WindowSpec # end ================================================ FILE: src/dataframe.jl ================================================ ############################################################################### # DataFrame # ############################################################################### Base.show(df::DataFrame) = jcall(df.jdf, "show", Nothing, ()) Base.show(df::DataFrame, n::Integer) = jcall(df.jdf, "show", Nothing, (jint,), n) function Base.show(io::IO, df::DataFrame) if df.isstreaming() print(io, toString(df.jdf)) else show(df) end end printSchema(df::DataFrame) = jcall(df.jdf, "printSchema", Nothing, ()) function Base.getindex(df::DataFrame, name::String) jcol = jcall(df.jdf, "col", JColumn, (JString,), name) return Column(jcol) end function Base.getproperty(df::DataFrame, prop::Symbol) if hasfield(DataFrame, prop) return getfield(df, prop) elseif string(prop) in columns(df) return df[string(prop)] else fn = getfield(@__MODULE__, prop) return DotChainer(df, fn) end end function columns(df::DataFrame) jnames = jcall(df.jdf, "columns", Vector{JString}, ()) names = [unsafe_string(jn) for jn in jnames] return names end Base.count(df::DataFrame) = jcall(df.jdf, "count", jlong, ()) Base.first(df::DataFrame) = Row(jcall(df.jdf, "first", JObject, ())) head(df::DataFrame) = Row(jcall(df.jdf, "head", JObject, ())) function head(df::DataFrame, n::Integer) jobjs = jcall(df.jdf, "head", JObject, (jint,), n) jrows = convert(Vector{JRow}, jobjs) return map(Row, jrows) end function Base.collect(df::DataFrame) jobj = jcall(df.jdf, "collect", JObject, ()) jrows = convert(Vector{JRow}, jobj) return map(Row, jrows) end function 
Base.collect(df::DataFrame, col::Union{<:AbstractString, <:Integer}) rows = collect(df) return [row[col] for row in rows] end function take(df::DataFrame, n::Integer) return convert(Vector{Row}, jcall(df.jdf, "take", JObject, (jint,), n)) end function describe(df::DataFrame, cols::String...) jdf = jcall(df.jdf, "describe", JDataset, (Vector{JString},), collect(cols)) return DataFrame(jdf) end function alias(df::DataFrame, name::String) jdf = jcall(df.jdf, "alias", JDataset, (JString,), name) return DataFrame(jdf) end function select(df::DataFrame, cols::Column...) jdf = jcall(df.jdf, "select", JDataset, (Vector{JColumn},), [col.jcol for col in cols]) return DataFrame(jdf) end select(df::DataFrame, cols::String...) = select(df, map(Column, cols)...) function withColumn(df::DataFrame, name::String, col::Column) jdf = jcall(df.jdf, "withColumn", JDataset, (JString, JColumn), name, col.jcol) return DataFrame(jdf) end function Base.filter(df::DataFrame, col::Column) jdf = jcall(df.jdf, "filter", JDataset, (JColumn,), col.jcol) return DataFrame(jdf) end where(df::DataFrame, col::Column) = filter(df, col) function groupby(df::DataFrame, cols::Column...) jgdf = jcall(df.jdf, "groupBy", JRelationalGroupedDataset, (Vector{JColumn},), [col.jcol for col in cols]) return GroupedData(jgdf) end function groupby(df::DataFrame, col::String, cols::String...) jgdf = jcall(df.jdf, "groupBy", JRelationalGroupedDataset, (JString, Vector{JString},), col, collect(cols)) return GroupedData(jgdf) end const groupBy = groupby for func in (:min, :max, :count, :sum, :mean) @eval function $func(df::DataFrame, cols::String...) jdf = jcall(df.jdf, string($func), JDataset, (Vector{JString},), collect(cols)) return DataFrame(jdf) end end minimum(df::DataFrame, cols::String...) = min(df, cols...) maximum(df::DataFrame, cols::String...) = max(df, cols...) avg(df::DataFrame, cols::String...) = mean(df, cols...) function Base.join(df1::DataFrame, df2::DataFrame, col::Column, typ::String="inner") jdf = jcall(df1.jdf, "join", JDataset, (JDataset, JColumn, JString), df2.jdf, col.jcol, typ) return DataFrame(jdf) end createOrReplaceTempView(df::DataFrame, name::AbstractString) = jcall(df.jdf, "createOrReplaceTempView", Nothing, (JString,), name) isstreaming(df::DataFrame) = Bool(jcall(df.jdf, "isStreaming", jboolean, ())) isStreaming(df::DataFrame) = isstreaming(df) function writeStream(df::DataFrame) jwriter = jcall(df.jdf, "writeStream", JDataStreamWriter, ()) return DataStreamWriter(jwriter) end ############################################################################### # GroupedData # ############################################################################### @chainable GroupedData function Base.show(io::IO, gdf::GroupedData) repr = jcall(gdf.jgdf, "toString", JString, ()) repr = replace(repr, "RelationalGroupedDataset" => "GroupedData") print(io, repr) end function agg(gdf::GroupedData, col::Column, cols::Column...) jdf = jcall(gdf.jgdf, "agg", JDataset, (JColumn, Vector{JColumn}), col.jcol, [col.jcol for col in cols]) return DataFrame(jdf) end function agg(gdf::GroupedData, ops::Dict{<:AbstractString, <:AbstractString}) jmap = convert(JMap, ops) jdf = jcall(gdf.jgdf, "agg", JDataset, (JMap,), jmap) return DataFrame(jdf) end for func in (:min, :max, :sum, :mean) @eval function $func(gdf::GroupedData, cols::String...) jdf = jcall(gdf.jgdf, string($func), JDataset, (Vector{JString},), collect(cols)) return DataFrame(jdf) end end minimum(gdf::GroupedData, cols::String...) = min(gdf, cols...) 
maximum(gdf::GroupedData, cols::String...) = max(gdf, cols...) avg(gdf::GroupedData, cols::String...) = mean(gdf, cols...) Base.count(gdf::GroupedData) = DataFrame(jcall(gdf.jgdf, "count", JDataset, ())) function write(df::DataFrame) jwriter = jcall(df.jdf, "write", JDataFrameWriter, ()) return DataFrameWriter(jwriter) end ================================================ FILE: src/defs.jl ================================================ import Base: min, max, minimum, maximum, sum, count import Statistics: mean const JSparkConf = @jimport org.apache.spark.SparkConf const JRuntimeConfig = @jimport org.apache.spark.sql.RuntimeConfig const JSparkContext = @jimport org.apache.spark.SparkContext const JJavaSparkContext = @jimport org.apache.spark.api.java.JavaSparkContext const JRDD = @jimport org.apache.spark.rdd.RDD const JJavaRDD = @jimport org.apache.spark.api.java.JavaRDD const JSparkSession = @jimport org.apache.spark.sql.SparkSession const JSparkSessionBuilder = @jimport org.apache.spark.sql.SparkSession$Builder const JDataFrameReader = @jimport org.apache.spark.sql.DataFrameReader const JDataFrameWriter = @jimport org.apache.spark.sql.DataFrameWriter const JDataStreamReader = @jimport org.apache.spark.sql.streaming.DataStreamReader const JDataStreamWriter = @jimport org.apache.spark.sql.streaming.DataStreamWriter const JStreamingQuery = @jimport org.apache.spark.sql.streaming.StreamingQuery const JDataset = @jimport org.apache.spark.sql.Dataset const JRelationalGroupedDataset = @jimport org.apache.spark.sql.RelationalGroupedDataset # const JRowFactory = @jimport org.apache.spark.sql.RowFactory const JGenericRow = @jimport org.apache.spark.sql.catalyst.expressions.GenericRow const JGenericRowWithSchema = @jimport org.apache.spark.sql.catalyst.expressions.GenericRowWithSchema const JRow = @jimport org.apache.spark.sql.Row const JColumn = @jimport org.apache.spark.sql.Column const JDataType = @jimport org.apache.spark.sql.types.DataType const JMetadata = @jimport org.apache.spark.sql.types.Metadata const JStructType = @jimport org.apache.spark.sql.types.StructType const JStructField = @jimport org.apache.spark.sql.types.StructField const JSQLFunctions = @jimport org.apache.spark.sql.functions const JWindow = @jimport org.apache.spark.sql.expressions.Window const JWindowSpec = @jimport org.apache.spark.sql.expressions.WindowSpec const JInteger = @jimport java.lang.Integer const JLong = @jimport java.lang.Long const JFloat = @jimport java.lang.Float const JDouble = @jimport java.lang.Double const JBoolean = @jimport java.lang.Boolean const JDate = @jimport java.sql.Date const JTimestamp = @jimport java.sql.Timestamp const JMap = @jimport java.util.Map const JHashMap = @jimport java.util.HashMap const JList = @jimport java.util.List const JArrayList = @jimport java.util.ArrayList # const JWrappedArray = @jimport scala.collection.mutable.WrappedArray const JArraySeq = @jimport scala.collection.mutable.ArraySeq const JSeq = @jimport scala.collection.immutable.Seq toString(jobj::JavaObject) = jcall(jobj, "toString", JString, ()) ############################################################################### # Type Definitions # ############################################################################### "Builder for [`SparkSession`](@ref)" struct SparkSessionBuilder jbuilder::JSparkSessionBuilder end "The entry point to programming Spark with the Dataset and DataFrame API" struct SparkSession jspark::JSparkSession end "User-facing configuration API, accessible through SparkSession.conf" 
struct RuntimeConfig jconf::JRuntimeConfig end "A distributed collection of data grouped into named columns" struct DataFrame jdf::JDataset end "A set of methods for aggregations on a `DataFrame`, created by `DataFrame.groupBy()`" struct GroupedData # here we use PySpark's type name, not the underlying Scala's name jgdf::JRelationalGroupedDataset end "A column in a DataFrame" struct Column jcol::JColumn end "A row in DataFrame" struct Row jrow::JRow end "Struct type, consisting of a list of [`StructField`](@ref)" struct StructType jst::JStructType end "A field in [`StructType`](@ref)" struct StructField jsf::JStructField end "Utility functions for defining window in DataFrames" struct Window jwin::JWindow end "A window specification that defines the partitioning, ordering, and frame boundaries" struct WindowSpec jwin::JWindowSpec end "Interface used to load a `DataFrame` from external storage systems" struct DataFrameReader jreader::JDataFrameReader end "Interface used to write a `DataFrame` to external storage systems" struct DataFrameWriter jwriter::JDataFrameWriter end "Interface used to load a streaming `DataFrame` from external storage systems" struct DataStreamReader jreader::JDataStreamReader end "Interface used to write a streaming `DataFrame` to external" struct DataStreamWriter jwriter::JDataStreamWriter end "A handle to a query that is executing continuously in the background as new data arrives" struct StreamingQuery jquery::JStreamingQuery end ================================================ FILE: src/init.jl ================================================ const JSystem = @jimport java.lang.System global const SPARK_DEFAULT_PROPS = Dict() function set_log_level(log_level::String) JLogger = @jimport org.apache.log4j.Logger JLevel = @jimport org.apache.log4j.Level level = jfield(JLevel, log_level, JLevel) for logger_name in ("org", "akka") logger = jcall(JLogger, "getLogger", JLogger, (JString,), logger_name) jcall(logger, "setLevel", Nothing, (JLevel,), level) end end function init(; log_level="WARN") if JavaCall.isloaded() @warn "JVM already initialized, this call will have no effect" return end JavaCall.addClassPath(get(ENV, "CLASSPATH", "")) defaults = load_spark_defaults(SPARK_DEFAULT_PROPS) shome = get(ENV, "SPARK_HOME", "") if !isempty(shome) for x in readdir(joinpath(shome, "jars")) JavaCall.addClassPath(joinpath(shome, "jars", x)) end JavaCall.addClassPath(joinpath(dirname(@__FILE__), "..", "jvm", "sparkjl", "target", "sparkjl-0.2.jar")) else JavaCall.addClassPath(joinpath(dirname(@__FILE__), "..", "jvm", "sparkjl", "target", "sparkjl-0.2-assembly.jar")) end for y in split(get(ENV, "SPARK_DIST_CLASSPATH", ""), [':',';'], keepempty=false) JavaCall.addClassPath(String(y)) end for z in split(get(defaults, "spark.driver.extraClassPath", ""), [':',';'], keepempty=false) JavaCall.addClassPath(String(z)) end JavaCall.addClassPath(get(ENV, "HADOOP_CONF_DIR", "")) JavaCall.addClassPath(get(ENV, "YARN_CONF_DIR", "")) if get(ENV, "HDP_VERSION", "") == "" try ENV["HDP_VERSION"] = pipeline(`hdp-select status` , `grep spark2-client` , `awk -F " " '{print $3}'`) |> (cmd -> read(cmd, String)) |> strip catch end end for y in split(get(defaults, "spark.driver.extraJavaOptions", ""), " ", keepempty=false) JavaCall.addOpts(String(y)) end s = get(defaults, "spark.driver.extraLibraryPath", "") try JavaCall.addOpts("-Djava.library.path=$(defaults["spark.driver.extraLibraryPath"])") catch; end JavaCall.addOpts("-ea") JavaCall.addOpts("-Xmx1024M") JavaCall.init() validateJavaVersion() 
set_log_level(log_level) end function validateJavaVersion() version::String = jcall(JSystem, "getProperty", JString, (JString,), "java.version") if !startswith(version, "1.8") && !startswith(version, "11.") @warn "Java 1.8 or 1.11 is recommended for Spark.jl, but Java $version was used." end end function load_spark_defaults(d::Dict) sconf = get(ENV, "SPARK_CONF", "") if sconf == "" shome = get(ENV, "SPARK_HOME", "") if shome == "" ; return d; end sconf = joinpath(shome, "conf") end spark_defaults_locs = [joinpath(sconf, "spark-defaults.conf"), joinpath(sconf, "spark-defaults.conf.template")] conf_idx = findfirst(isfile, spark_defaults_locs) if conf_idx == 0 error("Can't find spark-defaults.conf, looked at: $spark_defaults_locs") else spark_defaults_conf = spark_defaults_locs[conf_idx] end p = split(Base.read(spark_defaults_conf, String), '\n', keepempty=false) for x in p if !startswith(x, "#") && !isempty(strip(x)) y=split(x, limit=2) if size(y,1)==1 y=split(x, "=", limit=2) end d[y[1]]=strip(y[2]) end end return d end ================================================ FILE: src/io.jl ================================================ ############################################################################### # DataFrameReader # ############################################################################### @chainable DataFrameReader Base.show(io::IO, ::DataFrameReader) = print(io, "DataFrameReader()") function format(reader::DataFrameReader, src::String) jcall(reader.jreader, "format", JDataFrameReader, (JString,), src) return reader end for (T, JT) in [(String, JString), (Integer, jlong), (Real, jdouble), (Bool, jboolean)] @eval function option(reader::DataFrameReader, key::String, value::$T) jcall(reader.jreader, "option", JDataFrameReader, (JString, $JT), key, value) return reader end end for func in (:csv, :json, :parquet, :orc, :text, :textFile) @eval function $func(reader::DataFrameReader, paths::String...) jdf = jcall(reader.jreader, string($func), JDataset, (Vector{JString},), collect(paths)) return DataFrame(jdf) end end function load(reader::DataFrameReader, paths::String...) # TODO: test with zero paths jdf = jcall(reader.jreader, "load", JDataset, (Vector{JString},), collect(paths)) return DataFrame(jdf) end ############################################################################### # DataFrameWriter # ############################################################################### @chainable DataFrameWriter Base.show(io::IO, ::DataFrameWriter) = print(io, "DataFrameWriter()") function format(writer::DataFrameWriter, fmt::String) jcall(writer.jwriter, "format", JDataFrameWriter, (JString,), fmt) return writer end function mode(writer::DataFrameWriter, m::String) jcall(writer.jwriter, "mode", JDataFrameWriter, (JString,), m) return writer end for (T, JT) in [(String, JString), (Integer, jlong), (Real, jdouble), (Bool, jboolean)] @eval function option(writer::DataFrameWriter, key::String, value::$T) jcall(writer.jwriter, "option", JDataFrameWriter, (JString, $JT), key, value) return writer end end for func in (:csv, :json, :parquet, :orc, :text) @eval function $func(writer::DataFrameWriter, path::String) jcall(writer.jwriter, string($func), Nothing, (JString,), path) end end ================================================ FILE: src/row.jl ================================================ ############################################################################### # Row # ############################################################################### function Row(; kv...) 
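    # Illustrative note (not from the original source): this keyword constructor
    # derives the schema from the Julia types of the values via julia2ddl
    # (src/convert.jl), so e.g.
    #   Row(name="Alice", age=12)
    # should produce a row whose schema reads "name string, age long", matching
    # the rows built this way in test/test_sql.jl.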
ks = map(string, keys(kv)) vs = collect(values(values(kv))) flds = [StructField(k, julia2ddl(typeof(v)), true) for (k, v) in zip(ks, vs)] st = StructType(flds...) jrow = JGenericRowWithSchema((Vector{JObject}, JStructType,), vs, st.jst) jrow = convert(JRow, jrow) return Row(jrow) end function Row(vals::Vector) jseq = convert(JSeq, vals) jrow = jcall(JRow, "fromSeq", JRow, (JSeq,), jseq) return Row(jrow) end function Row(vals...) return Row(collect(vals)) end function Base.show(io::IO, row::Row) str = jcall(row.jrow, "toString", JString, ()) print(io, str) end function Base.getindex(row::Row, i::Integer) jobj = jcall(row.jrow, "get", JObject, (jint,), i - 1) class_name = getname(getclass(jobj)) JT = JavaObject{Symbol(class_name)} T = java2julia(JT) return convert(T, convert(JT, jobj)) # TODO: test all 4 types end function Base.getindex(row::Row, name::String) i = jcall(row.jrow, "fieldIndex", jint, (JString,), name) return row[i + 1] end function schema(row::Row) jst = jcall(row.jrow, "schema", JStructType, ()) return isnull(jst) ? nothing : StructType(jst) end function Base.getproperty(row::Row, prop::Symbol) if hasfield(Row, prop) return getfield(row, prop) end sch = schema(row) if !isnothing(sch) && string(prop) in names(sch) return row[string(prop)] else fn = getfield(@__MODULE__, prop) return DotChainer(row, fn) end end Base.:(==)(row1::Row, row2::Row) = Bool(jcall(row1.jrow, "equals", jboolean, (JObject,), row2.jrow)) ================================================ FILE: src/session.jl ================================================ ############################################################################### # SparkSession.Builder # ############################################################################### @chainable SparkSessionBuilder Base.show(io::IO, ::SparkSessionBuilder) = print(io, "SparkSessionBuilder()") function appName(builder::SparkSessionBuilder, name::String) jcall(builder.jbuilder, "appName", JSparkSessionBuilder, (JString,), name) return builder end function master(builder::SparkSessionBuilder, uri::String) jcall(builder.jbuilder, "master", JSparkSessionBuilder, (JString,), uri) return builder end for JT in (JString, JDouble, JLong, JBoolean) T = java2julia(JT) @eval function config(builder::SparkSessionBuilder, key::String, value::$T) jcall(builder.jbuilder, "config", JSparkSessionBuilder, (JString, $JT), key, value) return builder end end function enableHiveSupport(builder::SparkSessionBuilder) jcall(builder.jbuilder, "enableHiveSupport", JSparkSessionBuilder, ()) return builder end function getOrCreate(builder::SparkSessionBuilder) config(builder, "spark.jars", joinpath(dirname(@__FILE__), "..", "jvm", "sparkjl", "target", "sparkjl-0.2.jar")) jspark = jcall(builder.jbuilder, "getOrCreate", JSparkSession, ()) return SparkSession(jspark) end ############################################################################### # SparkSession # ############################################################################### @chainable SparkSession Base.show(io::IO, ::SparkSession) = print(io, "SparkSession()") function Base.getproperty(::Type{SparkSession}, prop::Symbol) if prop == :builder jbuilder = jcall(JSparkSession, "builder", JSparkSessionBuilder, ()) return SparkSessionBuilder(jbuilder) else return getfield(SparkSession, prop) end end Base.close(spark::SparkSession) = jcall(spark.jspark, "close", Nothing, ()) stop(spark::SparkSession) = jcall(spark.jspark, "stop", Nothing, ()) function read(spark::SparkSession) jreader = jcall(spark.jspark, "read", 
JDataFrameReader, ()) return DataFrameReader(jreader) end # note: write() method is defined in dataframe.jl # runtime config function conf(spark::SparkSession) jconf = jcall(spark.jspark, "conf", JRuntimeConfig, ()) return RuntimeConfig(jconf) end function createDataFrame(spark::SparkSession, rows::Vector{Row}, sch::StructType) if !isempty(rows) row = rows[1] rsch = row.schema() if !isnothing(rsch) && rsch != sch @warn "Schema mismatch:\n\trow : $(row.schema())\n\tprovided: $sch" end end jrows = [row.jrow for row in rows] jrows_arr = convert(JArrayList, jrows) jdf = jcall(spark.jspark, "createDataFrame", JDataset, (JList, JStructType), jrows_arr, sch.jst) return DataFrame(jdf) end function createDataFrame(spark::SparkSession, rows::Vector{Row}, sch::Union{String, Vector{String}}) st = StructType(sch) return spark.createDataFrame(rows, st) end function createDataFrame(spark::SparkSession, data::Vector{Vector{Any}}, sch::Union{String, Vector{String}}) rows = map(Row, data) st = StructType(sch) return spark.createDataFrame(rows, st) end function createDataFrame(spark::SparkSession, rows::Vector{Row}) @assert !isempty(rows) "Cannot create a DataFrame from empty list of rows" st = rows[1].schema() return spark.createDataFrame(rows, st) end function sql(spark::SparkSession, query::String) jdf = jcall(spark.jspark, "sql", JDataset, (JString,), query) return DataFrame(jdf) end ############################################################################### # RuntimeConfig # ############################################################################### @chainable RuntimeConfig Base.show(io::IO, cnf::RuntimeConfig) = print(io, "RuntimeConfig()") Base.get(cnf::RuntimeConfig, name::String) = jcall(cnf.jconf, "get", JString, (JString,), name) Base.get(cnf::RuntimeConfig, name::String, default::String) = jcall(cnf.jconf, "get", JString, (JString, JString), name, default) function getAll(cnf::RuntimeConfig) jmap = jcall(cnf.jconf, "getAll", @jimport(scala.collection.immutable.Map), ()) jiter = jcall(jmap, "iterator", @jimport(scala.collection.Iterator), ()) ret = Dict{String, Any}() while Bool(jcall(jiter, "hasNext", jboolean, ())) jobj = jcall(jiter, "next", JObject, ()) e = convert(@jimport(scala.Tuple2), jobj) key = convert(JString, jcall(e, "_1", JObject, ())) |> unsafe_string jval = jcall(e, "_2", JObject, ()) cls_name = getname(getclass(jval)) val = if cls_name == "java.lang.String" unsafe_string(convert(JString, jval)) else "(value type $cls_name is not supported)" end ret[key] = val end return ret end for JT in (JString, jlong, jboolean) T = java2julia(JT) @eval function set(cnf::RuntimeConfig, key::String, value::$T) jcall(cnf.jconf, "set", Nothing, (JString, $JT), key, value) end end ================================================ FILE: src/streaming.jl ================================================ ############################################################################### # DataStreamReader # ############################################################################### Base.show(io::IO, stream::DataStreamReader) = print(io, "DataStreamReader()") @chainable DataStreamReader function readStream(spark::SparkSession) jreader = jcall(spark.jspark, "readStream", JDataStreamReader, ()) return DataStreamReader(jreader) end function format(stream::DataStreamReader, fmt::String) jreader = jcall(stream.jreader, "format", JDataStreamReader, (JString,), fmt) return DataStreamReader(jreader) end function schema(stream::DataStreamReader, sch::StructType) jreader = jcall(stream.jreader, "schema", 
JDataStreamReader, (JStructType,), sch.jst) return DataStreamReader(jreader) end function schema(stream::DataStreamReader, sch::String) jreader = jcall(stream.jreader, "schema", JDataStreamReader, (JString,), sch) return DataStreamReader(jreader) end for (T, JT) in [(String, JString), (Integer, jlong), (Real, jdouble), (Bool, jboolean)] @eval function option(stream::DataStreamReader, key::String, value::$T) jcall(stream.jreader, "option", JDataStreamReader, (JString, $JT), key, value) return stream end end for func in (:csv, :json, :parquet, :orc, :text, :textFile) @eval function $func(stream::DataStreamReader, path::String) jdf = jcall(stream.jreader, string($func), JDataset, (JString,), path) return DataFrame(jdf) end end function load(stream::DataStreamReader, path::String) jdf = jcall(stream.jreader, "load", JDataset, (JString,), path) return DataFrame(jdf) end function load(stream::DataStreamReader) jdf = jcall(stream.jreader, "load", JDataset, ()) return DataFrame(jdf) end ############################################################################### # DataStreamWriter # ############################################################################### Base.show(io::IO, stream::DataStreamWriter) = print(io, "DataStreamWriter()") @chainable DataStreamWriter function format(writer::DataStreamWriter, fmt::String) jcall(writer.jwriter, "format", JDataStreamWriter, (JString,), fmt) return writer end function outputMode(writer::DataStreamWriter, m::String) jcall(writer.jwriter, "outputMode", JDataStreamWriter, (JString,), m) return writer end for (T, JT) in [(String, JString), (Integer, jlong), (Real, jdouble), (Bool, jboolean)] @eval function option(writer::DataStreamWriter, key::String, value::$T) jcall(writer.jwriter, "option", JDataStreamWriter, (JString, $JT), key, value) return writer end end function foreach(writer::DataStreamWriter, jfew::JObject) # Spark doesn't automatically distribute dynamically created objects to workers # Thus I turn off this feature for now error("Not implemented yet") # JForeachWriter = @jimport(org.apache.spark.sql.ForeachWriter) # jfew = convert(JForeachWriter, jfew) # jwriter = jcall(writer.jwriter, "foreach", JDataStreamWriter, (JForeachWriter,), jfew) # return DataStreamWriter(jwriter) end function start(writer::DataStreamWriter) jquery = jcall(writer.jwriter, "start", JStreamingQuery, ()) return StreamingQuery(jquery) end ############################################################################### # StreamingQuery # ############################################################################### Base.show(io::IO, query::StreamingQuery) = print(io, "StreamingQuery()") @chainable StreamingQuery function awaitTermination(query::StreamingQuery) jcall(query.jquery, "awaitTermination", Nothing, ()) end function awaitTermination(query::StreamingQuery, timeout::Integer) return Bool(jcall(query.jquery, "awaitTermination", jboolean, (jlong,), timeout)) end isActive(query::StreamingQuery) = Bool(jcall(query.jquery, "isActive", jboolean, ())) stop(query::StreamingQuery) = jcall(query.jquery, "stop", Nothing, ()) explain(query::StreamingQuery) = jcall(query.jquery, "explain", Nothing, ()) explain(query::StreamingQuery, extended::Bool) = jcall(query.jquery, "explain", Nothing, (jboolean,), extended) # TODO: foreach, foreachBatch ================================================ FILE: src/struct.jl ================================================ ############################################################################### # StructType # 
############################################################################### StructType() = StructType(JStructType(())) function StructType(flds::StructField...) st = StructType() for fld in flds st = add(st, fld) end return st end function StructType(sch::Vector{<:AbstractString}) flds = StructField[] for name_ddl in sch name, ddl = split(strip(name_ddl), " ") push!(flds, StructField(name, ddl, true)) end return StructType(flds...) end function StructType(sch::String) return StructType(split(sch, ",")) end @chainable StructType Base.show(io::IO, st::StructType) = print(io, jcall(st.jst, "toString", JString, ())) fieldNames(st::StructType) = convert(Vector{String}, jcall(st.jst, "fieldNames", Vector{JString}, ())) Base.names(st::StructType) = fieldNames(st) add(st::StructType, sf::StructField) = StructType(jcall(st.jst, "add", JStructType, (JStructField,), sf.jsf)) Base.getindex(st::StructType, idx::Integer) = StructField(jcall(st.jst, "apply", JStructField, (jint,), idx - 1)) Base.getindex(st::StructType, name::String) = StructField(jcall(st.jst, "apply", JStructField, (JString,), name)) Base.:(==)(st1::StructType, st2::StructType) = Bool(jcall(st1.jst, "equals", jboolean, (JObject,), st2.jst)) ############################################################################### # StructField # ############################################################################### function StructField(name::AbstractString, typ::AbstractString, nullable::Bool) dtyp = jcall(JDataType, "fromDDL", JDataType, (JString,), typ) empty_metadata = jcall(JMetadata, "empty", JMetadata, ()) jsf = jcall( JStructField, "apply", JStructField, (JString, JDataType, jboolean, JMetadata), name, dtyp, nullable, empty_metadata ) return StructField(jsf) end Base.show(io::IO, sf::StructField) = print(io, jcall(sf.jsf, "toString", JString, ())) Base.:(==)(st1::StructField, st2::StructField) = Bool(jcall(st1.jsf, "equals", jboolean, (JObject,), st2.jsf)) ================================================ FILE: src/window.jl ================================================ ############################################################################### # Window & WindowSpec # ############################################################################### @chainable WindowSpec function Base.getproperty(W::Type{Window}, prop::Symbol) if hasfield(typeof(W), prop) return getfield(W, prop) elseif prop in (:currentRow, :unboundedFollowing, :unboundedPreceding) return jcall(JWindow, string(prop), jlong, ()) else fn = getfield(@__MODULE__, prop) return DotChainer(W, fn) end end Base.show(io::IO, win::Window) = print(io, "Window()") Base.show(io::IO, win::WindowSpec) = print(io, "WindowSpec()") for (WT, jobj) in [(WindowSpec, :(win.jwin)), (Type{Window}, JWindow)] @eval function orderBy(win::$WT, cols::Column...) jwin = jcall($jobj, "orderBy", JWindowSpec, (Vector{JColumn},), [col.jcol for col in cols]) return WindowSpec(jwin) end @eval function orderBy(win::$WT, col::String, cols::String...) jwin = jcall($jobj, "orderBy", JWindowSpec, (JString, Vector{JString},), col, collect(cols)) return WindowSpec(jwin) end @eval function partitionBy(win::$WT, cols::Column...) jwin = jcall($jobj, "partitionBy", JWindowSpec, (Vector{JColumn},), [col.jcol for col in cols]) return WindowSpec(jwin) end @eval function partitionBy(win::$WT, col::String, cols::String...) 
jwin = jcall($jobj, "partitionBy", JWindowSpec, (JString, Vector{JString},), col, collect(cols)) return WindowSpec(jwin) end @eval function rangeBetween(win::$WT, start::Column, finish::Column) jwin = jcall($jobj, "rangeBetween", JWindowSpec, (JColumn, JColumn), start.jcol, finish.jcol) return WindowSpec(jwin) end @eval function rangeBetween(win::$WT, start::Integer, finish::Integer) jwin = jcall($jobj, "rangeBetween", JWindowSpec, (jlong, jlong), start, finish) return WindowSpec(jwin) end @eval function rowsBetween(win::$WT, start::Column, finish::Column) jwin = jcall($jobj, "rowsBetween", JWindowSpec, (JColumn, JColumn), start.jcol, finish.jcol) return WindowSpec(jwin) end @eval function rowsBetween(win::$WT, start::Integer, finish::Integer) jwin = jcall($jobj, "rowsBetween", JWindowSpec, (jlong, jlong), start, finish) return WindowSpec(jwin) end end ================================================ FILE: test/data/people.json ================================================ [{"name": "Peter", "age": 32}, {"name": "Belle", "age": 27}] ================================================ FILE: test/data/people2.json ================================================ [{"name": "Peter", "age": 32}, {"name": "Belle", "age": 27}, {"name": "Peter", "age": 27}] ================================================ FILE: test/runtests.jl ================================================ if Sys.isunix() ENV["JULIA_COPY_STACKS"] = 1 end using Test using Spark import Statistics.mean Spark.set_log_level("ERROR") spark = Spark.SparkSession.builder. appName("Hello"). master("local"). config("some.key", "some-value"). getOrCreate() include("test_chainable.jl") include("test_convert.jl") include("test_compiler.jl") include("test_sql.jl") spark.stop() # include("rdd/test_rdd.jl") ================================================ FILE: test/test_chainable.jl ================================================ import Spark: @chainable struct Foo x::Int end @chainable Foo struct Bar a::Int end @chainable Bar add(foo::Foo, y) = foo.x + y to_bar(foo::Foo) = Bar(foo.x) mul(bar::Bar, b) = bar.a * b @testset "chainable" begin foo = Foo(2.0); y = rand(); b = rand() # field access @test foo.x == 2.0 # dot syntax @test foo.add(y) == add(foo, y) # chained field access @test foo.to_bar().a == 2.0 # chained dot syntax @test foo.to_bar().mul(b) == mul(foo.to_bar(), b) # implicit call @test foo.to_bar.mul(b) == mul(foo.to_bar(), b) # correct type @test foo.to_bar isa Spark.DotChainer end ================================================ FILE: test/test_compiler.jl ================================================ import Spark: jcall2, udf import Spark.JavaCall: @jimport, jdouble, JString const JDouble = @jimport java.lang.Double @testset "Compiler" begin f = (x, y) -> 2x + y f_udf = udf(f, 2.0, 3.0) r = jcall2(f_udf.judf, "call", jdouble, (JDouble, JDouble), 5.0, 6.0) @test r == f(5.0, 6.0) f = s -> lowercase(s) f_udf = udf(f, "Hi!") r = jcall2(f_udf.judf, "call", JString, (JString,), "Big Buddha Boom!") @test convert(String, r) == f("Big Buddha Boom!") end ================================================ FILE: test/test_convert.jl ================================================ using Dates @testset "Convert" begin # create DateTime without fractional part t = now(Dates.UTC) |> datetime2unix |> floor |> unix2datetime d = Date(t) @test convert(DateTime, convert(Spark.JTimestamp, t)) == t @test convert(Date, convert(Spark.JDate, d)) == d end ================================================ FILE: test/test_sql.jl 
================================================ using Spark using Spark.Compiler @testset "Builder" begin cnf = spark.conf.getAll() @test cnf["spark.app.name"] == "Hello" @test cnf["spark.master"] == "local" @test cnf["some.key"] == "some-value" end @testset "SparkSession" begin df = spark.sql("select 1 as num") @test df.collect("num") == [1] end @testset "RuntimeConfig" begin @test spark.conf.get("spark.app.name") == "Hello" spark.conf.set("another.key", "another-value") @test spark.conf.get("another.key") == "another-value" @test spark.conf.get("non.existing", "default-value") == "default-value" end @testset "DataFrame" begin rows = [Row(name="Alice", age=12), Row(name="Bob", age=32)] @test spark.createDataFrame(rows) isa DataFrame @test spark.createDataFrame(rows, StructType("name string, age long")) isa DataFrame df = spark.createDataFrame(rows) @test df.columns() == ["name", "age"] @test df.first() == rows[1] @test df.head() == rows[1] @test df.head(2) == rows @test df.take(1) == rows[1:1] @test df.collect() == rows @test df.count() == 2 @test df.select("age", "name").columns() == ["age", "name"] rows = df.select(Column("age") + 1).collect() @test [row[1] for row in rows] == [13, 33] rows = df.withColumn("inc_age", df.age + 1).collect() @test [row[3] for row in rows] == [13, 33] @test df.filter(df.name == "Alice").first().age == 12 @test df.where(df.name == "Alice").first().age == 12 df2 = spark.createDataFrame( [Any["Alice", "Smith"], ["Emily", "Clark"]], "first_name string, last_name string" ) joined_df = df.join(df2, df.name == df2.first_name) @test joined_df.columns() == ["name", "age", "first_name", "last_name"] @test joined_df.count() == 1 joined_df = df.join(df2, df.name == df2.first_name, "outer") @test joined_df.count() == 3 df.createOrReplaceTempView("people") @test spark.sql("select count(*) from people").first()[1] == 2 end @testset "GroupedData" begin data = [ ["red", "banana", 1, 10], ["blue", "banana", 2, 20], ["red", "carrot", 3, 30], ["blue", "grape", 4, 40], ["red", "carrot", 5, 50], ["black", "carrot", 6, 60], ["red", "banana", 7, 70], ["red", "grape", 8, 80] ] sch = ["color string", "fruit string", "v1 long", "v2 long"] df = spark.createDataFrame(data, sch) gdf = df.groupby("fruit") @test gdf isa GroupedData df_agg = gdf.agg(min(df.v1), max(df.v2)) @test df_agg.collect("min(v1)") == [4, 1, 3] @test df_agg.collect("max(v2)") == [80, 70, 60] df_agg = gdf.agg(Dict("v1" => "min", "v2" => "max")) @test df_agg.collect("min(v1)") == [4, 1, 3] @test df_agg.collect("max(v2)") == [80, 70, 60] @test gdf.sum("v1").select(mean(Column("sum(v1)"))).collect(1)[1] == 12.0 end @testset "Column" begin col = Column("x") for func in (+, -, *, /) @test func(col, 1) isa Column @test func(col, 1.0) isa Column end @test col.alias("y") isa Column @test col.asc() isa Column @test col.asc_nulls_first() isa Column @test col.asc_nulls_last() isa Column @test col.between(1, 2) isa Column @test col.bitwiseAND(1) isa Column @test col & 1 isa Column @test col.bitwiseOR(1) isa Column @test col | 1 isa Column @test col.bitwiseXOR(1) isa Column @test col ⊻ 1 isa Column @test col.contains("a") isa Column @test col.desc() isa Column @test col.desc_nulls_first() isa Column @test col.desc_nulls_last() isa Column # prints 'Exception in thread "main" java.lang.NoSuchMethodError: endsWith' # but seems to work @test col.endswith("a") isa Column @test col.endswith(Column("other")) isa Column @test col.eqNullSafe("other") isa Column @test (col == Column("other")) isa Column @test (col == "abc") isa Column 
@test (col != Column("other")) isa Column @test (col != "abc") isa Column col.explain() # smoke test @test col.isNull() isa Column @test col.isNotNull() isa Column @test col.like("abc") isa Column @test col.rlike("abc") isa Column @test_broken col.when(Column("flag"), 1).otherwise("abc") isa Colon @test col.over() isa Column # also complains about NoSuchMethodError, but seems to work @test col.startswith("a") isa Column @test col.startswith(Column("other")) isa Column @test col.substr(Column("start"), Column("len")) isa Column @test col.substr(0, 3) isa Column @test col.explode() |> string == """col("explode(x)")""" @test col.split("|") |> string == """col("split(x, |, -1)")""" @test (col.window("10 minutes", "5 minutes", "15 minutes") |> string == """col("window(x, 600000000, 300000000, 900000000) AS window")""") @test (col.window("10 minutes", "5 minutes") |> string == """col("window(x, 600000000, 300000000, 0) AS window")""") @test (col.window("10 minutes") |> string == """col("window(x, 600000000, 600000000, 0) AS window")""") end @testset "StructType" begin st = StructType() @test length(st.fieldNames()) == 0 st = StructType( StructField("name", "string", false), StructField("age", "int", true) ) @test st[1] == StructField("name", "string", false) end @testset "Window" begin # how can we do these tests more robust? @test Window.partitionBy(Column("x")).orderBy(Column("y").desc()) isa WindowSpec @test Window.partitionBy("x").orderBy("y") isa WindowSpec @test Window.partitionBy("x").orderBy("y").rowsBetween(-3, 3) isa WindowSpec @test Window.partitionBy("x").orderBy("y").rangeBetween(-3, 3) isa WindowSpec @test Window.partitionBy("x").rangeBetween(Window.unboundedPreceding, Window.unboundedFollowing) isa WindowSpec @test Window.partitionBy("x").rangeBetween(Window.unboundedPreceding, Window.currentRow) isa WindowSpec end @testset "Reader/Writer" begin # for REPL: # data_dir = joinpath(@__DIR__, "test", "data") data_dir = joinpath(@__DIR__, "data") mktempdir(; prefix="spark-jl-") do tmp_dir df = spark.read.json(joinpath(data_dir, "people.json")) df.write.mode("overwrite").parquet(joinpath(tmp_dir, "people.parquet")) df = spark.read.parquet(joinpath(tmp_dir, "people.parquet")) df.write.mode("overwrite").orc(joinpath(tmp_dir, "people.orc")) df = spark.read.orc(joinpath(tmp_dir, "people.orc")) @test df.collect("name") |> Set == Set(["Peter", "Belle"]) end end @testset "Streaming" begin # for REPL: # data_dir = joinpath(@__DIR__, "test", "data") data_dir = joinpath(@__DIR__, "data") sch = StructType("name string, age long") # df = spark.readStream.schema(sch).json(joinpath(data_dir, "people.json")) df = spark.readStream.schema(sch).json(data_dir) @test df.isstreaming() query = df.writeStream. format("console"). option("numRows", 5). outputMode("append"). start() query.explain() query.explain(true) @test query.isActive() query.awaitTermination(100) query.stop() @test !query.isActive() # df = spark.readStream.schema(sch).json(data_dir) # jfew = create_instance(""" # package spark.jl; # import java.io.Serializable; # import org.apache.spark.sql.ForeachWriter; # class JuliaWriter extends ForeachWriter implements Serializable { # private static final long serialVersionUID = 1L; # @Override public boolean open(long partitionId, long version) { # return true; # } # @Override public void process(String record) { # System.out.println(record); # } # @Override public void close(Throwable errorOrNull) { # } # } # """) # query = df.writeStream.foreach(jfew).start() end
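# Hedged sketch (not part of the test suite above): `using Spark.Compiler` at the
# top of this file re-exports `create_instance` and `jcall2`, which the commented
# Streaming example relies on. Assuming the embedded Java compiler is available,
# a minimal standalone use might look like:
#
#   jobj = create_instance("""
#       package spark.jl;
#       public class Echo {
#           public String call(String s) { return s; }
#       }
#       """)
#   jcall2(jobj, "call", Spark.JavaCall.JString, (Spark.JavaCall.JString,), "hi")
#
# `Echo` is a hypothetical class used only for illustration; `jcall2` first tries
# a plain `jcall` and falls back to reflection when methods on runtime-generated
# classes cannot be resolved (see src/compiler.jl).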