Repository: neo4j-contrib/neo4j-spark-connector Branch: 5.0 Commit: c70d2995c935 Files: 135 Total size: 2.0 MB Directory structure: gitextract_girzruyy/ ├── .commitlintrc.json ├── .github/ │ ├── CODEOWNERS │ ├── ISSUE_TEMPLATE/ │ │ ├── bug_report.md │ │ └── feature_request.md │ └── dependabot.yml ├── .gitignore ├── .husky/ │ ├── commit-msg │ └── pre-commit ├── .mvn/ │ └── wrapper/ │ └── maven-wrapper.properties ├── .teamcity/ │ ├── .editorconfig │ ├── builds/ │ │ ├── Build.kt │ │ ├── Common.kt │ │ ├── Empty.kt │ │ ├── JavaIntegrationTests.kt │ │ ├── Maven.kt │ │ ├── PRCheck.kt │ │ ├── Package.kt │ │ ├── PythonIntegrationTests.kt │ │ ├── Release.kt │ │ ├── SemgrepCheck.kt │ │ └── WhiteListCheck.kt │ ├── pom.xml │ └── settings.kts ├── LICENSE.txt ├── README.md ├── common/ │ ├── LICENSES.txt │ ├── NOTICE.txt │ ├── pom.xml │ └── src/ │ ├── main/ │ │ ├── java/ │ │ │ └── org/ │ │ │ └── neo4j/ │ │ │ └── spark/ │ │ │ └── util/ │ │ │ └── ReflectionUtils.java │ │ ├── resources/ │ │ │ └── neo4j-spark-connector.properties │ │ └── scala/ │ │ └── org/ │ │ └── neo4j/ │ │ └── spark/ │ │ ├── config/ │ │ │ └── TopN.scala │ │ ├── converter/ │ │ │ ├── DataConverter.scala │ │ │ └── TypeConverter.scala │ │ ├── cypher/ │ │ │ ├── Cypher5Renderer.scala │ │ │ └── CypherVersionSelector.scala │ │ ├── reader/ │ │ │ └── BasePartitionReader.scala │ │ ├── service/ │ │ │ ├── MappingService.scala │ │ │ ├── Neo4jQueryService.scala │ │ │ └── SchemaService.scala │ │ ├── streaming/ │ │ │ └── BaseStreamingPartitionReader.scala │ │ ├── util/ │ │ │ ├── DriverCache.scala │ │ │ ├── Neo4jImplicits.scala │ │ │ ├── Neo4jOptions.scala │ │ │ ├── Neo4jUtil.scala │ │ │ ├── ValidationUtil.scala │ │ │ └── Validations.scala │ │ └── writer/ │ │ ├── BaseDataWriter.scala │ │ └── DataWriterMetrics.scala │ └── test/ │ └── scala/ │ └── org/ │ └── neo4j/ │ └── spark/ │ ├── CommonTestSuiteIT.scala │ ├── CommonTestSuiteWithApocIT.scala │ ├── service/ │ │ ├── AuthenticationTest.scala │ │ ├── Neo4jQueryServiceIT.scala │ │ ├── 
Neo4jQueryServiceTest.scala │ │ ├── SchemaServiceTSE.scala │ │ ├── SchemaServiceTest.scala │ │ └── SchemaServiceWithApocTSE.scala │ └── util/ │ ├── DummyNamedReference.scala │ ├── Neo4jImplicitsTest.scala │ ├── Neo4jOptionsIT.scala │ ├── Neo4jOptionsTest.scala │ ├── Neo4jUtilTest.scala │ ├── ValidationsIT.scala │ └── ValidationsTest.scala ├── dangerfile.mjs ├── examples/ │ ├── neo4j_data_engineering.ipynb │ └── neo4j_data_science.ipynb ├── jreleaser.yml ├── maven-release.sh ├── mvnw ├── mvnw.cmd ├── package.json ├── pom.xml ├── scripts/ │ ├── python/ │ │ ├── requirements.txt │ │ └── test_spark.py │ └── release/ │ └── upload_to_spark_packages.sh ├── spark-3/ │ ├── LICENSES.txt │ ├── NOTICE.txt │ ├── pom.xml │ └── src/ │ ├── jreleaser/ │ │ └── assemblers/ │ │ └── zip/ │ │ └── README.txt.tpl │ ├── main/ │ │ ├── assemblies/ │ │ │ └── spark-packages-assembly.xml │ │ ├── distributions/ │ │ │ └── spark-packages.pom │ │ ├── resources/ │ │ │ ├── META-INF/ │ │ │ │ └── services/ │ │ │ │ └── org.apache.spark.sql.sources.DataSourceRegister │ │ │ └── neo4j-spark-connector.properties │ │ └── scala/ │ │ └── org/ │ │ └── neo4j/ │ │ └── spark/ │ │ ├── DataSource.scala │ │ ├── Neo4jTable.scala │ │ ├── reader/ │ │ │ ├── Neo4jPartitionReader.scala │ │ │ ├── Neo4jPartitionReaderFactory.scala │ │ │ ├── Neo4jScan.scala │ │ │ └── Neo4jScanBuilder.scala │ │ ├── streaming/ │ │ │ ├── Neo4jMicroBatchReader.scala │ │ │ ├── Neo4jOffset.scala │ │ │ ├── Neo4jStreamingDataWriterFactory.scala │ │ │ ├── Neo4jStreamingPartitionReader.scala │ │ │ ├── Neo4jStreamingPartitionReaderFactory.scala │ │ │ └── Neo4jStreamingWriter.scala │ │ └── writer/ │ │ ├── Neo4jBatchWriter.scala │ │ ├── Neo4jDataWriter.scala │ │ ├── Neo4jDataWriterFactory.scala │ │ └── Neo4jWriterBuilder.scala │ └── test/ │ ├── java/ │ │ └── org/ │ │ └── neo4j/ │ │ └── spark/ │ │ ├── DataSourceReaderTypesTSE.java │ │ └── SparkConnectorSuiteIT.java │ ├── resources/ │ │ ├── log4j2.properties │ │ ├── neo4j-keycloak.jks │ │ └── 
neo4j-sso-test-realm.json │ └── scala/ │ └── org/ │ └── neo4j/ │ └── spark/ │ ├── DataSourceAggregationTSE.scala │ ├── DataSourceReaderNeo4jTSE.scala │ ├── DataSourceReaderNeo4jWithApocTSE.scala │ ├── DataSourceReaderTSE.scala │ ├── DataSourceReaderWithApocTSE.scala │ ├── DataSourceSchemaWriterTSE.scala │ ├── DataSourceStreamingReaderTSE.scala │ ├── DataSourceStreamingWriterTSE.scala │ ├── DataSourceWriterNeo4jSkipNullKeysTSE.scala │ ├── DataSourceWriterNeo4jTSE.scala │ ├── DataSourceWriterTSE.scala │ ├── DefaultConfigTSE.scala │ ├── GraphDataScienceIT.scala │ ├── ReauthenticationIT.scala │ ├── SparkConnector30ScalaSuiteIT.scala │ ├── SparkConnector30ScalaSuiteWithApocIT.scala │ ├── SparkConnectorAuraTest.scala │ └── TransactionTimeoutIT.scala └── test-support/ ├── pom.xml └── src/ ├── main/ │ ├── java/ │ │ └── org/ │ │ └── neo4j/ │ │ └── spark/ │ │ └── Assert.java │ ├── resources/ │ │ └── simplelogger.properties │ └── scala/ │ └── org/ │ └── neo4j/ │ ├── Closeables.scala │ ├── Neo4jContainerExtension.scala │ └── spark/ │ ├── RowUtil.scala │ ├── SparkConnectorScalaBaseTSE.scala │ ├── SparkConnectorScalaBaseWithApocTSE.scala │ ├── SparkConnectorScalaSuiteIT.scala │ ├── SparkConnectorScalaSuiteWithApocIT.scala │ ├── SparkConnectorScalaSuiteWithGdsBase.scala │ └── TestUtil.scala └── test/ └── scala/ └── org/ └── neo4j/ └── spark/ └── VersionTest.scala ================================================ FILE CONTENTS ================================================ ================================================ FILE: .commitlintrc.json ================================================ { "extends": [ "@commitlint/config-conventional" ] } ================================================ FILE: .github/CODEOWNERS ================================================ * @neo4j/team-connectors /.github/ @ali-ince @fbiville @venikkin ================================================ FILE: .github/ISSUE_TEMPLATE/bug_report.md ================================================ --- name: 
Bug report about: Create a report to help us improve labels: bug --- ## Guidelines Please note that GitHub issues are only meant for bug reports/feature requests. If you have questions on how to use the Neo4j Connector for Apache Spark, please ask on [the Neo4j Discussion Forum](https://community.neo4j.com/c/integrations/18) instead of creating an issue here. ## Expected Behavior (Mandatory) ## Actual Behavior (Mandatory) ## How to Reproduce the Problem ### Simple Dataset (where it's possible) ``` // Insert the output of the `df.show()` call ``` ### Steps (Mandatory) 1. 1. 1. ## Screenshots (where it's possible) ## Specifications (Mandatory) Currently used versions ### Versions - Spark: - Scala: - Neo4j: - Neo4j Connector: ## Additional information * The code of the Spark job * the structure of the Dataframe * did you define the constraints/indexes? * if you're you using any Spark Cloud provider please specify it (ie: Databricks) ================================================ FILE: .github/ISSUE_TEMPLATE/feature_request.md ================================================ --- name: Feature request about: Suggest an idea for this project --- ## Guidelines Please note that GitHub issues are only meant for bug reports/feature requests. If you have questions on how to use the Neo4j Connector for Apache Spark, please ask on [the Neo4j Discussion Forum](https://community.neo4j.com/c/integrations/18) instead of creating an issue here. ## Feature description (Mandatory) A clear and concise description of what you want to happen. Add any considered drawbacks. ## Considered alternatives A clear and concise description of any alternative solutions or features you've considered. Maybe there is something in the project that could be reused? ## How this feature can improve the project? If you can, explain how users will be able to use this and possibly write out a version the docs. Maybe a screenshot or design? 
================================================ FILE: .github/dependabot.yml ================================================ version: 2 updates: - package-ecosystem: "github-actions" directory: "/" schedule: interval: "daily" cooldown: default-days: 3 - package-ecosystem: "maven" directory: "/" target-branch: "5.0" schedule: interval: "daily" cooldown: default-days: 3 - package-ecosystem: "maven" directory: "/" target-branch: "6.0" schedule: interval: "daily" cooldown: default-days: 3 ================================================ FILE: .gitignore ================================================ neo4j-home .gradle gradle/ build/ *~ \#* target out .project .classpath .settings .externalToolBuilders/ .scala_dependencies .factorypath .cache .cache-main .cache-tests *.iws *.ipr *.iml .idea .DS_Store .shell_history .mailmap .java-version .cache-main .cache-tests Thumbs.db .cache-main .cache-tests docs/guides doc/node doc/node_modules doc/package-lock.json scripts/python/local node_modules ================================================ FILE: .husky/commit-msg ================================================ #!/usr/bin/env sh npx --no -- commitlint --edit "$1" ================================================ FILE: .husky/pre-commit ================================================ #!/usr/bin/env sh ./mvnw sortpom:sort spotless:apply -f .teamcity ./mvnw sortpom:sort spotless:apply git update-index --again ================================================ FILE: .mvn/wrapper/maven-wrapper.properties ================================================ # Licensed to the Apache Software Foundation (ASF) under one # or more contributor license agreements. See the NOTICE file # distributed with this work for additional information # regarding copyright ownership. The ASF licenses this file # to you under the Apache License, Version 2.0 (the # "License"); you may not use this file except in compliance # with the License. 
You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, # software distributed under the License is distributed on an # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. wrapperVersion=3.3.2 distributionType=only-script distributionUrl=https://repo.maven.apache.org/maven2/org/apache/maven/apache-maven/3.9.9/apache-maven-3.9.9-bin.zip ================================================ FILE: .teamcity/.editorconfig ================================================ # This .editorconfig section approximates ktfmt's formatting rules. You can include it in an # existing .editorconfig file or use it standalone by copying it to /.editorconfig # and making sure your editor is set to read settings from .editorconfig files. # # It includes editor-specific config options for IntelliJ IDEA. 
# # If any option is wrong, PR are welcome [*] max_line_length = unset [pom.xml] max_line_length = 180 [{*.kt,*.kts}] indent_style = space insert_final_newline = true max_line_length = 100 indent_size = 2 ij_continuation_indent_size = 4 ij_java_names_count_to_use_import_on_demand = 9999 ij_kotlin_align_in_columns_case_branch = false ij_kotlin_align_multiline_binary_operation = false ij_kotlin_align_multiline_extends_list = false ij_kotlin_align_multiline_method_parentheses = false ij_kotlin_align_multiline_parameters = true ij_kotlin_align_multiline_parameters_in_calls = false ij_kotlin_allow_trailing_comma = true ij_kotlin_allow_trailing_comma_on_call_site = true ij_kotlin_assignment_wrap = normal ij_kotlin_blank_lines_after_class_header = 0 ij_kotlin_blank_lines_around_block_when_branches = 0 ij_kotlin_blank_lines_before_declaration_with_comment_or_annotation_on_separate_line = 1 ij_kotlin_block_comment_at_first_column = true ij_kotlin_call_parameters_new_line_after_left_paren = true ij_kotlin_call_parameters_right_paren_on_new_line = false ij_kotlin_call_parameters_wrap = on_every_item ij_kotlin_catch_on_new_line = false ij_kotlin_class_annotation_wrap = split_into_lines ij_kotlin_code_style_defaults = KOTLIN_OFFICIAL ij_kotlin_continuation_indent_for_chained_calls = true ij_kotlin_continuation_indent_for_expression_bodies = true ij_kotlin_continuation_indent_in_argument_lists = true ij_kotlin_continuation_indent_in_elvis = false ij_kotlin_continuation_indent_in_if_conditions = false ij_kotlin_continuation_indent_in_parameter_lists = false ij_kotlin_continuation_indent_in_supertype_lists = false ij_kotlin_else_on_new_line = false ij_kotlin_enum_constants_wrap = off ij_kotlin_extends_list_wrap = normal ij_kotlin_field_annotation_wrap = split_into_lines ij_kotlin_finally_on_new_line = false ij_kotlin_if_rparen_on_new_line = false ij_kotlin_import_nested_classes = false ij_kotlin_insert_whitespaces_in_simple_one_line_method = true 
ij_kotlin_keep_blank_lines_before_right_brace = 2 ij_kotlin_keep_blank_lines_in_code = 2 ij_kotlin_keep_blank_lines_in_declarations = 2 ij_kotlin_keep_first_column_comment = true ij_kotlin_keep_indents_on_empty_lines = false ij_kotlin_keep_line_breaks = true ij_kotlin_lbrace_on_next_line = false ij_kotlin_line_comment_add_space = false ij_kotlin_line_comment_at_first_column = true ij_kotlin_method_annotation_wrap = split_into_lines ij_kotlin_method_call_chain_wrap = normal ij_kotlin_method_parameters_new_line_after_left_paren = true ij_kotlin_method_parameters_right_paren_on_new_line = true ij_kotlin_method_parameters_wrap = on_every_item ij_kotlin_name_count_to_use_star_import = 9999 ij_kotlin_name_count_to_use_star_import_for_members = 9999 ij_kotlin_parameter_annotation_wrap = off ij_kotlin_space_after_comma = true ij_kotlin_space_after_extend_colon = true ij_kotlin_space_after_type_colon = true ij_kotlin_space_before_catch_parentheses = true ij_kotlin_space_before_comma = false ij_kotlin_space_before_extend_colon = true ij_kotlin_space_before_for_parentheses = true ij_kotlin_space_before_if_parentheses = true ij_kotlin_space_before_lambda_arrow = true ij_kotlin_space_before_type_colon = false ij_kotlin_space_before_when_parentheses = true ij_kotlin_space_before_while_parentheses = true ij_kotlin_spaces_around_additive_operators = true ij_kotlin_spaces_around_assignment_operators = true ij_kotlin_spaces_around_equality_operators = true ij_kotlin_spaces_around_function_type_arrow = true ij_kotlin_spaces_around_logical_operators = true ij_kotlin_spaces_around_multiplicative_operators = true ij_kotlin_spaces_around_range = false ij_kotlin_spaces_around_relational_operators = true ij_kotlin_spaces_around_unary_operator = false ij_kotlin_spaces_around_when_arrow = true ij_kotlin_variable_annotation_wrap = off ij_kotlin_while_on_new_line = false ij_kotlin_wrap_elvis_expressions = 1 ij_kotlin_wrap_expression_body_functions = 1 ij_kotlin_wrap_first_method_in_call_chain 
= false ================================================ FILE: .teamcity/builds/Build.kt ================================================ package builds import jetbrains.buildServer.configs.kotlin.BuildType import jetbrains.buildServer.configs.kotlin.Project import jetbrains.buildServer.configs.kotlin.buildFeatures.notifications import jetbrains.buildServer.configs.kotlin.sequential import jetbrains.buildServer.configs.kotlin.toId class Build( name: String, forPullRequests: Boolean, javaVersions: Set, scalaVersions: Set, pysparkVersions: Set, neo4jVersions: Set, forCompatibility: Boolean = false, customizeCompletion: BuildType.() -> Unit = {} ) : Project( { this.id(name.toId()) this.name = name val complete = Empty("${name}-complete", "complete") val bts = sequential { if (forPullRequests) buildType(WhiteListCheck("${name}-whitelist-check", "white-list check")) if (forPullRequests) dependentBuildType(PRCheck("${name}-pr-check", "pr check")) parallel { scalaVersions.forEach { scala -> dependentBuildType( SemgrepCheck( "${name}-semgrep-check-${scala.version}", "semgrep check (${scala.version})", scala)) } javaVersions.cartesianProduct(scalaVersions, neo4jVersions).forEach { (java, scala, neo4j) -> sequential { val packaging = Package( "${name}-package-${java.version}-${scala.version}-${neo4j.version}", "package (${java.version}, ${scala.version}, ${neo4j.version})", java, scala, ) dependentBuildType( Maven( "${name}-build-${java.version}-${scala.version}-${neo4j.version}", "build (${java.version}, ${scala.version}, ${neo4j.version})", "test-compile", java, scala, ), ) dependentBuildType( Maven( "${name}-unit-tests-${java.version}-${scala.version}-${neo4j.version}", "unit tests (${java.version}, ${scala.version}, ${neo4j.version})", "test", java, scala, neo4j, ), ) dependentBuildType( collectArtifacts( packaging, ), ) parallel { dependentBuildType( JavaIntegrationTests( "${name}-integration-tests-java-${java.version}-${scala.version}-${neo4j.version}", "java 
integration tests (${java.version}, ${scala.version}, ${neo4j.version})", java, scala, neo4j, ) {}, ) pysparkVersions .filter { it.shouldTestWith(java, scala) } .forEach { pyspark -> pyspark.pythonVersions.forEach { python -> dependentBuildType( PythonIntegrationTests( "${name}-integration-tests-pyspark-${java.version}-${scala.version}-${neo4j.version}-${python.version}-${pyspark.sparkVersion.version}", "pyspark integration tests (${java.version}, ${scala.version}, ${neo4j.version}, ${python.version}, ${pyspark.sparkVersion.version})", java, python, scala, pyspark.sparkVersion, neo4j, ) { dependencies { artifacts(packaging) { artifactRules = """ +:packages/*.jar => ./scripts/python """ .trimIndent() } } }, ) } } } } } } dependentBuildType(complete) if (!forPullRequests && !forCompatibility) dependentBuildType(Release("${name}-release", "release", DEFAULT_JAVA_VERSION)) } bts.buildTypes().forEach { it.thisVcs(if (forPullRequests) "pull/*" else DEFAULT_BRANCH) it.features { loginToECR() requireDiskSpace("5gb") if (!forCompatibility) enableCommitStatusPublisher() if (forPullRequests) enablePullRequests() } buildType(it) } complete.features { notifications { branchFilter = buildString { appendLine("+:$DEFAULT_BRANCH") appendLine("+:refs/heads/$DEFAULT_BRANCH") if (forPullRequests) { appendLine("+:pull/*") appendLine("+:refs/heads/pull/*") } } queuedBuildRequiresApproval = forPullRequests buildFailedToStart = !forPullRequests buildFailed = !forPullRequests buildFinishedSuccessfully = !forPullRequests buildProbablyHanging = !forPullRequests notifierSettings = slackNotifier { connection = SLACK_CONNECTION_ID sendTo = SLACK_CHANNEL messageFormat = simpleMessageFormat() } } } complete.apply(customizeCompletion) }, ) ================================================ FILE: .teamcity/builds/Common.kt ================================================ package builds import builds.Neo4jSparkConnectorVcs.branchSpec import jetbrains.buildServer.configs.kotlin.BuildFeatures import 
jetbrains.buildServer.configs.kotlin.BuildSteps import jetbrains.buildServer.configs.kotlin.BuildType import jetbrains.buildServer.configs.kotlin.CompoundStage import jetbrains.buildServer.configs.kotlin.FailureAction import jetbrains.buildServer.configs.kotlin.Requirements import jetbrains.buildServer.configs.kotlin.ReuseBuilds import jetbrains.buildServer.configs.kotlin.buildFeatures.PullRequests import jetbrains.buildServer.configs.kotlin.buildFeatures.buildCache import jetbrains.buildServer.configs.kotlin.buildFeatures.commitStatusPublisher import jetbrains.buildServer.configs.kotlin.buildFeatures.dockerRegistryConnections import jetbrains.buildServer.configs.kotlin.buildFeatures.freeDiskSpace import jetbrains.buildServer.configs.kotlin.buildFeatures.pullRequests import jetbrains.buildServer.configs.kotlin.buildSteps.DockerCommandStep import jetbrains.buildServer.configs.kotlin.buildSteps.MavenBuildStep import jetbrains.buildServer.configs.kotlin.buildSteps.ScriptBuildStep import jetbrains.buildServer.configs.kotlin.buildSteps.dockerCommand import jetbrains.buildServer.configs.kotlin.buildSteps.maven import jetbrains.buildServer.configs.kotlin.buildSteps.script import jetbrains.buildServer.configs.kotlin.vcs.GitVcsRoot const val GITHUB_OWNER = "neo4j" const val GITHUB_REPOSITORY = "neo4j-spark-connector" const val DEFAULT_BRANCH = "5.0" val MAVEN_DEFAULT_ARGS = buildString { append("--no-transfer-progress ") append("--batch-mode ") append("-Dmaven.repo.local=%teamcity.build.checkoutDir%/.m2/repository ") append("-Dmaven.wagon.http.retryHandler.class=standard ") append("-Dmaven.wagon.http.retryHandler.timeout=60 ") append("-Dmaven.wagon.http.retryHandler.count=3 ") append( "-Dmaven.wagon.http.retryHandler.nonRetryableClasses=java.io.InterruptedIOException,java.net.UnknownHostException,java.net.ConnectException ") } const val SEMGREP_DOCKER_IMAGE = "semgrep/semgrep:1.146.0" const val FULL_GITHUB_REPOSITORY = "$GITHUB_OWNER/$GITHUB_REPOSITORY" const val GITHUB_URL 
= "https://github.com/$FULL_GITHUB_REPOSITORY" val DEFAULT_JAVA_VERSION = JavaVersion.V_11 // Look into Root Project's settings -> Connections const val SLACK_CONNECTION_ID = "PROJECT_EXT_83" const val SLACK_CHANNEL = "#team-connectors-feed" // Look into Root Project's settings -> Connections const val ECR_CONNECTION_ID = "PROJECT_EXT_124" enum class LinuxSize(val value: String) { SMALL("small"), LARGE("large") } enum class JavaVersion(val version: String, val dockerImage: String) { V_8(version = "8", dockerImage = "eclipse-temurin:8-jdk"), V_11(version = "11", dockerImage = "eclipse-temurin:11-jdk"), V_17(version = "17", dockerImage = "eclipse-temurin:17-jdk"), V_21(version = "21", dockerImage = "eclipse-temurin:21-jdk"), } enum class ScalaVersion(val version: String) { V2_12(version = "2.12"), V2_13(version = "2.13"), } enum class PythonVersion(val version: String) { V3_9(version = "3.9"), V3_10(version = "3.10"), V3_11(version = "3.11"), V3_12(version = "3.12"), V3_13(version = "3.13"), } enum class SparkVersion(val short: String, val version: String) { V3_4_4(short = "3", version = "3.4.4"), V3_5_5(short = "3", version = "3.5.5"), } enum class PySparkVersion( val sparkVersion: SparkVersion, val scalaVersion: ScalaVersion, val javaVersions: Set, val pythonVersions: Set, ) { V3_4( SparkVersion.V3_4_4, ScalaVersion.V2_12, setOf( JavaVersion.V_8, JavaVersion.V_11, JavaVersion.V_17, ), setOf( PythonVersion.V3_9, PythonVersion.V3_10, PythonVersion.V3_11, PythonVersion.V3_12, ), ), V3_5( SparkVersion.V3_5_5, ScalaVersion.V2_12, setOf( JavaVersion.V_8, JavaVersion.V_11, JavaVersion.V_17, JavaVersion.V_21, ), setOf( PythonVersion.V3_9, PythonVersion.V3_10, PythonVersion.V3_11, PythonVersion.V3_12, PythonVersion.V3_13, ), ), } fun PySparkVersion.shouldTestWith(javaVersion: JavaVersion, scalaVersion: ScalaVersion): Boolean = this.javaVersions.contains(javaVersion) && this.scalaVersion == scalaVersion enum class Neo4jVersion(val version: String, val dockerImage: String) { 
V_NONE("", ""), V_4_4("4.4", "neo4j:4.4-enterprise"), V_4_4_DEV( "4.4-dev", "535893049302.dkr.ecr.eu-west-1.amazonaws.com/build-service/neo4j:4.4-enterprise-debian-nightly", ), V_5("5", "neo4j:5-enterprise"), V_5_DEV( "5-dev", "535893049302.dkr.ecr.eu-west-1.amazonaws.com/build-service/neo4j:5-enterprise-debian-nightly-bundle", ), V_CALVER("2026", "neo4j:2026-enterprise"), V_CALVER_DEV( "2026-dev", "535893049302.dkr.ecr.eu-west-1.amazonaws.com/build-service/neo4j:2026-enterprise-debian-nightly-bundle", ), } fun Iterable.cartesianProduct( other1: Collection, other2: Collection ): Iterable> = this.flatMap { s -> other1.map { t -> s to t } } .flatMap { (s, t) -> other2.map { y -> Triple(s, t, y) } } object Neo4jSparkConnectorVcs : GitVcsRoot( { id("Connectors_Neo4jSparkConnector_Build") name = "neo4j-spark-connector" url = "git@github.com:neo4j/neo4j-spark-connector.git" branch = "refs/heads/$DEFAULT_BRANCH" branchSpec = "refs/heads/*" authMethod = defaultPrivateKey { userName = "git" } }, ) fun Requirements.runOnLinux(size: LinuxSize = LinuxSize.SMALL) { startsWith("cloud.amazon.agent-name-prefix", "linux-${size.value}") } fun BuildType.thisVcs(forBranch: String) = vcs { root(Neo4jSparkConnectorVcs) branchSpec = buildString { appendLine("-:*") appendLine("+:$forBranch") } cleanCheckout = true } fun BuildFeatures.enableCommitStatusPublisher() = commitStatusPublisher { vcsRootExtId = Neo4jSparkConnectorVcs.id.toString() publisher = github { githubUrl = "https://api.github.com" authType = personalToken { token = "%github-commit-status-token%" } } } fun BuildFeatures.enablePullRequests() = pullRequests { vcsRootExtId = Neo4jSparkConnectorVcs.id.toString() provider = github { authType = token { token = "%github-pull-request-token%" } filterAuthorRole = PullRequests.GitHubRoleFilter.EVERYBODY filterTargetBranch = buildString { appendLine("+:$DEFAULT_BRANCH") appendLine("+:refs/heads/$DEFAULT_BRANCH") } } } fun BuildFeatures.requireDiskSpace(size: String = "3gb") = 
freeDiskSpace { requiredSpace = size failBuild = true } fun BuildFeatures.loginToECR() = dockerRegistryConnections { cleanupPushedImages = true loginToRegistry = on { dockerRegistryId = ECR_CONNECTION_ID } } fun BuildFeatures.buildCache(javaVersion: JavaVersion, scalaVersion: ScalaVersion) = buildCache { this.name = "neo4j-spark-connector-${DEFAULT_BRANCH}-${javaVersion.version}-${scalaVersion.version}" publish = true use = true publishOnlyChanged = true rules = ".m2/repository" } fun CompoundStage.dependentBuildType(bt: BuildType, reuse: ReuseBuilds = ReuseBuilds.SUCCESSFUL) = buildType(bt) { onDependencyCancel = FailureAction.CANCEL onDependencyFailure = FailureAction.FAIL_TO_START reuseBuilds = reuse } fun collectArtifacts(buildType: BuildType): BuildType { buildType.artifactRules = """ +:spark-3/target/*_for_spark_*.jar => packages +:spark-3/target/*.zip => packages """ .trimIndent() return buildType } fun BuildSteps.runMaven(javaVersion: JavaVersion, init: MavenBuildStep.() -> Unit): MavenBuildStep { val maven = this.maven { dockerImagePlatform = MavenBuildStep.ImagePlatform.Linux dockerImage = javaVersion.dockerImage dockerRunParameters = "--volume /var/run/docker.sock:/var/run/docker.sock" localRepoScope = MavenBuildStep.RepositoryScope.MAVEN_DEFAULT } init(maven) return maven } fun BuildSteps.setVersion(name: String, version: String, javaVersion: JavaVersion): MavenBuildStep { return this.runMaven(javaVersion) { this.name = name goals = "versions:set" runnerArgs = "$MAVEN_DEFAULT_ARGS -Djava.version=${javaVersion.version} -DnewVersion=$version -DgenerateBackupPoms=false" } } fun BuildSteps.commitAndPush( name: String, commitMessage: String, includeFiles: String = "\\*pom.xml", dryRunParameter: String = "dry-run" ): ScriptBuildStep { return this.script { this.name = name scriptContent = """ #!/bin/bash -eu git add $includeFiles git commit -m "$commitMessage" git push """ .trimIndent() conditions { doesNotMatch(dryRunParameter, "true") } } } fun 
BuildSteps.pullImage(version: Neo4jVersion): DockerCommandStep = this.dockerCommand { name = "pull neo4j test image" commandType = other { subCommand = "image" commandArgs = "pull ${version.dockerImage}" } } ================================================ FILE: .teamcity/builds/Empty.kt ================================================ package builds import jetbrains.buildServer.configs.kotlin.BuildType import jetbrains.buildServer.configs.kotlin.toId class Empty(id: String, name: String) : BuildType({ this.id(id.toId()) this.name = name requirements { runOnLinux(LinuxSize.SMALL) } }) ================================================ FILE: .teamcity/builds/JavaIntegrationTests.kt ================================================ package builds import jetbrains.buildServer.configs.kotlin.BuildType import jetbrains.buildServer.configs.kotlin.toId class JavaIntegrationTests( id: String, name: String, javaVersion: JavaVersion, scalaVersion: ScalaVersion, neo4jVersion: Neo4jVersion, init: BuildType.() -> Unit ) : BuildType( { this.id(id.toId()) this.name = name init() artifactRules = """ +:diagnostics => diagnostics.zip """ .trimIndent() params { text("env.NEO4J_TEST_IMAGE", neo4jVersion.dockerImage) } steps { if (neo4jVersion != Neo4jVersion.V_NONE) { pullImage(neo4jVersion) } runMaven(javaVersion) { this.goals = "verify" this.runnerArgs = "$MAVEN_DEFAULT_ARGS -Djava.version=${javaVersion.version} -Dscala-${scalaVersion.version} -DskipUnitTests" } } features { buildCache(javaVersion, scalaVersion) } requirements { runOnLinux(LinuxSize.LARGE) } }, ) ================================================ FILE: .teamcity/builds/Maven.kt ================================================ package builds import jetbrains.buildServer.configs.kotlin.BuildType import jetbrains.buildServer.configs.kotlin.toId open class Maven( id: String, name: String, goals: String, javaVersion: JavaVersion, scalaVersion: ScalaVersion, neo4jVersion: Neo4jVersion = Neo4jVersion.V_NONE, args: String? 
= null ) : BuildType( { this.id(id.toId()) this.name = name params { text("env.JAVA_VERSION", javaVersion.version) text("env.NEO4J_TEST_IMAGE", neo4jVersion.dockerImage) } steps { if (neo4jVersion != Neo4jVersion.V_NONE) { pullImage(neo4jVersion) } runMaven(javaVersion) { this.goals = goals this.runnerArgs = "$MAVEN_DEFAULT_ARGS -Djava.version=${javaVersion.version} -Dscala-${scalaVersion.version} ${args ?: ""}" } } features { buildCache(javaVersion, scalaVersion) } requirements { runOnLinux(LinuxSize.SMALL) } }, ) ================================================ FILE: .teamcity/builds/PRCheck.kt ================================================ package builds import jetbrains.buildServer.configs.kotlin.BuildType import jetbrains.buildServer.configs.kotlin.buildFeatures.dockerSupport import jetbrains.buildServer.configs.kotlin.buildSteps.ScriptBuildStep import jetbrains.buildServer.configs.kotlin.buildSteps.script import jetbrains.buildServer.configs.kotlin.toId class PRCheck(id: String, name: String) : BuildType({ this.id(id.toId()) this.name = name steps { script { scriptContent = """ #!/bin/bash set -eu export DANGER_GITHUB_API_TOKEN=%github-pull-request-token% export PULL_REQUEST_URL=https://github.com/$GITHUB_OWNER/$GITHUB_REPOSITORY/%teamcity.build.branch% # process pull request npm ci npx danger ci --verbose --failOnErrors """ .trimIndent() dockerImage = "node:18.4" dockerImagePlatform = ScriptBuildStep.ImagePlatform.Linux } } features { dockerSupport {} } requirements { runOnLinux(LinuxSize.SMALL) } }) ================================================ FILE: .teamcity/builds/Package.kt ================================================ package builds import jetbrains.buildServer.configs.kotlin.BuildType import jetbrains.buildServer.configs.kotlin.buildSteps.ScriptBuildStep import jetbrains.buildServer.configs.kotlin.buildSteps.script import jetbrains.buildServer.configs.kotlin.toId class Package( id: String, name: String, javaVersion: JavaVersion, scalaVersion: 
ScalaVersion, ) : BuildType({ this.id(id.toId()) this.name = name params { text("env.JAVA_VERSION", javaVersion.version) } steps { script { scriptContent = """ ./maven-release.sh package ${scalaVersion.version} """ .trimIndent() dockerImagePlatform = ScriptBuildStep.ImagePlatform.Linux dockerImage = javaVersion.dockerImage dockerRunParameters = "--volume /var/run/docker.sock:/var/run/docker.sock" } } features { buildCache(javaVersion, scalaVersion) } requirements { runOnLinux(LinuxSize.SMALL) } }) ================================================ FILE: .teamcity/builds/PythonIntegrationTests.kt ================================================ package builds import jetbrains.buildServer.configs.kotlin.BuildType import jetbrains.buildServer.configs.kotlin.buildSteps.ScriptBuildStep import jetbrains.buildServer.configs.kotlin.buildSteps.script import jetbrains.buildServer.configs.kotlin.toId class PythonIntegrationTests( id: String, name: String, javaVersion: JavaVersion, pythonVersion: PythonVersion, scalaVersion: ScalaVersion, sparkVersion: SparkVersion, neo4jVersion: Neo4jVersion, init: BuildType.() -> Unit ) : BuildType( { this.id(id.toId()) this.name = name init() artifactRules = """ +:diagnostics => diagnostics.zip """ .trimIndent() params { text("env.NEO4J_TEST_IMAGE", neo4jVersion.dockerImage) } steps { if (neo4jVersion != Neo4jVersion.V_NONE) { pullImage(neo4jVersion) } script { scriptContent = """ #!/bin/bash -eu apt-get update apt-get install -o Acquire::Retries=10 --yes build-essential libssl-dev zlib1g-dev libbz2-dev libreadline-dev libsqlite3-dev curl git libncursesw5-dev xz-utils tk-dev libxml2-dev libxmlsec1-dev libffi-dev liblzma-dev curl -fsSL https://pyenv.run | bash export PYENV_ROOT="${'$'}HOME/.pyenv" export PATH="${'$'}PYENV_ROOT/bin:${'$'}PATH" eval "$(pyenv init - bash)" pyenv install ${pythonVersion.version} pyenv global ${pythonVersion.version} python -m pip install --upgrade pip pip install pyspark==${sparkVersion.version} 
"testcontainers[neo4j]" six tzlocal==2.1 project_version="$(./mvnw help:evaluate -Dexpression="project.version" --quiet -DforceStdout)" jar_name="neo4j-connector-apache-spark_${scalaVersion.version}-${'$'}{project_version}_for_spark_${sparkVersion.short}.jar" cd ./scripts/python python test_spark.py "${'$'}{jar_name}" "${neo4jVersion.dockerImage}" """ .trimIndent() dockerImagePlatform = ScriptBuildStep.ImagePlatform.Linux dockerImage = javaVersion.dockerImage dockerRunParameters = "--volume /var/run/docker.sock:/var/run/docker.sock" } } requirements { runOnLinux(LinuxSize.SMALL) } }, ) ================================================ FILE: .teamcity/builds/Release.kt ================================================ package builds import jetbrains.buildServer.configs.kotlin.AbsoluteId import jetbrains.buildServer.configs.kotlin.BuildType import jetbrains.buildServer.configs.kotlin.ParameterDisplay import jetbrains.buildServer.configs.kotlin.buildSteps.ScriptBuildStep import jetbrains.buildServer.configs.kotlin.buildSteps.script import jetbrains.buildServer.configs.kotlin.toId private const val DRY_RUN = "dry-run" class Release(id: String, name: String, javaVersion: JavaVersion) : BuildType( { this.id(id.toId()) this.name = name templates(AbsoluteId("FetchSigningKey")) params { text( "releaseVersion", "", label = "Version to release", display = ParameterDisplay.PROMPT, allowEmpty = false, ) text( "nextSnapshotVersion", "", label = "Next snapshot version", description = "Next snapshot version to set after release", display = ParameterDisplay.PROMPT, allowEmpty = false, ) checkbox( DRY_RUN, "true", "Dry run?", description = "Whether to perform a dry run where nothing is published and released", display = ParameterDisplay.PROMPT, checked = "true", unchecked = "false", ) password("env.JRELEASER_GITHUB_TOKEN", "%github-pull-request-token%") text("env.JRELEASER_DRY_RUN", "%$DRY_RUN%") text("env.JRELEASER_PROJECT_VERSION", "%releaseVersion%") 
text("env.JRELEASER_ANNOUNCE_SLACK_ACTIVE", "NEVER") text("env.JRELEASER_ANNOUNCE_SLACK_TOKEN", "%slack-token%") text("env.JRELEASER_ANNOUNCE_SLACK_WEBHOOK", "%slack-webhook%") password("env.JRELEASER_GPG_PASSPHRASE", "%signing-key-passphrase%") text("env.JRELEASER_MAVENCENTRAL_USERNAME", "%publish-username%") password("env.JRELEASER_MAVENCENTRAL_TOKEN", "%publish-password%") } steps { setVersion("Set release version", "%releaseVersion%", javaVersion) commitAndPush( "Push release version", "build: release version %releaseVersion%", dryRunParameter = DRY_RUN, ) script { scriptContent = """ #!/bin/bash set -eux apt-get update apt-get install -o Acquire::Retries=10 --yes build-essential curl git unzip zip # Get the jreleaser downloader curl -sL https://raw.githubusercontent.com/jreleaser/release-action/refs/tags/2.5.0/get_jreleaser.java > get_jreleaser.java # Download JReleaser java get_jreleaser.java 1.22.0 if [ "%dry-run%" = "true" ]; then echo "we are on a dry run, only performing upload to maven central" export JRELEASER_MAVENCENTRAL_STAGE=UPLOAD export JRELEASER_ANNOUNCE_SLACK_ACTIVE=NEVER else echo "we will do a full deploy to maven central" export JRELEASER_MAVENCENTRAL_STAGE=FULL export JRELEASER_ANNOUNCE_SLACK_ACTIVE=ALWAYS fi # Execute JReleaser java -jar jreleaser-cli.jar assemble java -jar jreleaser-cli.jar full-release --debug """ .trimIndent() dockerImagePlatform = ScriptBuildStep.ImagePlatform.Linux dockerImage = javaVersion.dockerImage dockerRunParameters = "--volume /var/run/docker.sock:/var/run/docker.sock --volume %teamcity.build.checkoutDir%/signingkeysandbox:/root/.gnupg" } setVersion("Set next snapshot version", "%nextSnapshotVersion%", javaVersion) commitAndPush( "Push next snapshot version", "build: update version to %nextSnapshotVersion%", dryRunParameter = DRY_RUN, ) } artifactRules = """ +:artifacts => artifacts +:out/jreleaser => jreleaser """ .trimIndent() dependencies { artifacts(AbsoluteId("Tools_ReleaseTool")) { buildRule = 
lastSuccessful() artifactRules = "rt.jar => lib" } } requirements { runOnLinux(LinuxSize.SMALL) } }, ) ================================================ FILE: .teamcity/builds/SemgrepCheck.kt ================================================ package builds import jetbrains.buildServer.configs.kotlin.buildSteps.ScriptBuildStep class SemgrepCheck(id: String, name: String, scalaVersion: ScalaVersion) : Maven( id, name, "dependency:tree", JavaVersion.V_17, scalaVersion, Neo4jVersion.V_NONE, "-DoutputFile=maven_dep_tree.txt") { init { params.password("env.SEMGREP_APP_TOKEN", "%semgrep-app-token%") params.text("env.SEMGREP_REPO_NAME", FULL_GITHUB_REPOSITORY) params.text("env.SEMGREP_REPO_URL", GITHUB_URL) params.text("env.SEMGREP_BRANCH", "%teamcity.build.branch%") params.text("env.SEMGREP_JOB_URL", "%env.BUILD_URL%") params.text("env.SEMGREP_COMMIT", "%env.BUILD_VCS_NUMBER%") steps.step( ScriptBuildStep { scriptContent = "semgrep ci --no-git-ignore" dockerImagePlatform = ScriptBuildStep.ImagePlatform.Linux dockerImage = SEMGREP_DOCKER_IMAGE dockerRunParameters = "--volume /var/run/docker.sock:/var/run/docker.sock --volume %teamcity.build.checkoutDir%/signingkeysandbox:/root/.gnupg" }) } } ================================================ FILE: .teamcity/builds/WhiteListCheck.kt ================================================ package builds import jetbrains.buildServer.configs.kotlin.AbsoluteId import jetbrains.buildServer.configs.kotlin.BuildType import jetbrains.buildServer.configs.kotlin.buildSteps.script import jetbrains.buildServer.configs.kotlin.toId class WhiteListCheck(id: String, name: String) : BuildType({ this.id(id.toId()) this.name = name dependencies { artifacts(AbsoluteId("Tools_WhitelistCheck")) { buildRule = lastSuccessful() cleanDestination = true artifactRules = "whitelist-check.tar.gz!** => whitelist-check/" } } steps { script { scriptContent = """ #!/bin/bash -eu BRANCH=%teamcity.pullRequest.source.branch% if [[ "${'$'}BRANCH" =~ dependabot/.* ]]; then 
echo "Raised by dependabot, skipping the white list check" exit 0 fi echo "Checking committers on PR %teamcity.build.branch%" TOKEN="%github-pull-request-token%" # process pull request ./whitelist-check/bin/examine-pull-request $GITHUB_OWNER $GITHUB_REPOSITORY "${'$'}{TOKEN}" %teamcity.build.branch% whitelist-check/cla-database.csv """ .trimIndent() formatStderrAsError = true } } requirements { runOnLinux(LinuxSize.SMALL) } }) ================================================ FILE: .teamcity/pom.xml ================================================ 4.0.0 org.jetbrains.teamcity configs-dsl-kotlin-parent 1.0-SNAPSHOT Connectors_Neo4jSparkConnector teamcity-pipeline 1.0-SNAPSHOT Connectors_Neo4jSparkConnector Config DSL Script UTF-8 4.0.0 2.40.0 org.jetbrains.kotlin kotlin-script-runtime ${kotlin.version} compile org.jetbrains.kotlin kotlin-stdlib-jdk8 ${kotlin.version} compile org.jetbrains.teamcity configs-dsl-kotlin-latest ${teamcity.dsl.version} compile org.jetbrains.teamcity configs-dsl-kotlin-plugins-latest 1.0-SNAPSHOT pom compile true jetbrains-all https://download.jetbrains.com/teamcity-repository true teamcity-server https://live.neo4j-build.io/app/dsl-plugins-repository JetBrains https://download.jetbrains.com/teamcity-repository org.jetbrains.kotlin kotlin-maven-plugin ${kotlin.version} compile compile process-sources test-compile test-compile process-test-sources org.jetbrains.teamcity teamcity-configs-maven-plugin ${teamcity.dsl.version} kotlin target/generated-configs com.github.ekryd.sortpom sortpom-maven-plugin ${sortpom-maven-plugin.version} ${project.build.sourceEncoding} false schemaLocation 4 true scope,groupId,artifactId false false verify validate STOP com.diffplug.spotless spotless-maven-plugin ${spotless-maven-plugin.version} **/*.kt **/*.kts 0.46 2 4 true 100 check compile ${basedir} ================================================ FILE: .teamcity/settings.kts ================================================ import builds.Build import 
builds.DEFAULT_BRANCH import builds.JavaVersion import builds.Neo4jSparkConnectorVcs import builds.Neo4jVersion import builds.PySparkVersion import builds.ScalaVersion import jetbrains.buildServer.configs.kotlin.Project import jetbrains.buildServer.configs.kotlin.failureConditions.BuildFailureOnText import jetbrains.buildServer.configs.kotlin.failureConditions.failOnText import jetbrains.buildServer.configs.kotlin.project import jetbrains.buildServer.configs.kotlin.triggers.schedule import jetbrains.buildServer.configs.kotlin.triggers.vcs import jetbrains.buildServer.configs.kotlin.version version = "2025.11" project { params { text("default-spark-branch", DEFAULT_BRANCH) text("osssonatypeorg-username", "%publish-username%") password("osssonatypeorg-password", "%publish-password%") password("signing-key-passphrase", "%publish-signing-key-password%") password("github-commit-status-token", "%github-token%") password("github-pull-request-token", "%github-token%") password("semgrep-app-token", "%semgrep-token%") } vcsRoot(Neo4jSparkConnectorVcs) subProject( Build( name = "main", javaVersions = setOf(JavaVersion.V_8, JavaVersion.V_11, JavaVersion.V_17, JavaVersion.V_21), scalaVersions = setOf(ScalaVersion.V2_12, ScalaVersion.V2_13), pysparkVersions = setOf(PySparkVersion.V3_4, PySparkVersion.V3_5), neo4jVersions = setOf(Neo4jVersion.V_4_4, Neo4jVersion.V_5, Neo4jVersion.V_CALVER), forPullRequests = false, ) { triggers { vcs { this.branchFilter = buildString { appendLine("+:$DEFAULT_BRANCH") appendLine("+:refs/heads/$DEFAULT_BRANCH") } this.triggerRules = """ -:comment=^build.*release version.*:** -:comment=^build.*update version.*:** """ .trimIndent() } } }, ) subProject( Build( name = "pull-request", javaVersions = setOf(JavaVersion.V_8, JavaVersion.V_11, JavaVersion.V_17), scalaVersions = setOf(ScalaVersion.V2_12, ScalaVersion.V2_13), pysparkVersions = setOf(PySparkVersion.V3_5), neo4jVersions = setOf(Neo4jVersion.V_4_4, Neo4jVersion.V_5, Neo4jVersion.V_CALVER), 
forPullRequests = true, ) { triggers { vcs { this.branchFilter = buildString { appendLine("+:pull/*") appendLine("+:refs/heads/pull/*") } } } // when a PR gets closed, TC falls back to main branch to run the pipeline, which we don't // want failureConditions { failOnText { conditionType = BuildFailureOnText.ConditionType.CONTAINS pattern = "which does not correspond to any branch monitored by the build VCS roots" failureMessage = "Error: The branch %teamcity.build.branch% does not exist" reverse = false stopBuildOnFailure = true } } }, ) subProject( Project { this.id("compatibility") name = "compatibility" Neo4jVersion.entries.minus(Neo4jVersion.V_NONE).forEach { neo4j -> subProject( Build( name = neo4j.version, javaVersions = setOf(JavaVersion.V_8, JavaVersion.V_11, JavaVersion.V_17, JavaVersion.V_21), scalaVersions = setOf(ScalaVersion.V2_12, ScalaVersion.V2_13), pysparkVersions = setOf(PySparkVersion.V3_4, PySparkVersion.V3_5), neo4jVersions = setOf(neo4j), forPullRequests = false, forCompatibility = true, ) { triggers { vcs { enabled = false } schedule { branchFilter = buildString { appendLine("+:$DEFAULT_BRANCH") appendLine("+:refs/heads/$DEFAULT_BRANCH") } schedulingPolicy = daily { hour = 7 minute = 0 } triggerBuild = always() withPendingChangesOnly = false enforceCleanCheckout = true enforceCleanCheckoutForDependencies = true } } }, ) } }, ) } ================================================ FILE: LICENSE.txt ================================================ Apache License Version 2.0, January 2004 http://www.apache.org/licenses/ TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 1. Definitions. "License" shall mean the terms and conditions for use, reproduction, and distribution as defined by Sections 1 through 9 of this document. "Licensor" shall mean the copyright owner or entity authorized by the copyright owner that is granting the License. 
"Legal Entity" shall mean the union of the acting entity and all other entities that control, are controlled by, or are under common control with that entity. For the purposes of this definition, "control" means (i) the power, direct or indirect, to cause the direction or management of such entity, whether by contract or otherwise, or (ii) ownership of fifty percent (50%) or more of the outstanding shares, or (iii) beneficial ownership of such entity. "You" (or "Your") shall mean an individual or Legal Entity exercising permissions granted by this License. "Source" form shall mean the preferred form for making modifications, including but not limited to software source code, documentation source, and configuration files. "Object" form shall mean any form resulting from mechanical transformation or translation of a Source form, including but not limited to compiled object code, generated documentation, and conversions to other media types. "Work" shall mean the work of authorship, whether in Source or Object form, made available under the License, as indicated by a copyright notice that is included in or attached to the work (an example is provided in the Appendix below). "Derivative Works" shall mean any work, whether in Source or Object form, that is based on (or derived from) the Work and for which the editorial revisions, annotations, elaborations, or other modifications represent, as a whole, an original work of authorship. For the purposes of this License, Derivative Works shall not include works that remain separable from, or merely link (or bind by name) to the interfaces of, the Work and Derivative Works thereof. 
"Contribution" shall mean any work of authorship, including the original version of the Work and any modifications or additions to that Work or Derivative Works thereof, that is intentionally submitted to Licensor for inclusion in the Work by the copyright owner or by an individual or Legal Entity authorized to submit on behalf of the copyright owner. For the purposes of this definition, "submitted" means any form of electronic, verbal, or written communication sent to the Licensor or its representatives, including but not limited to communication on electronic mailing lists, source code control systems, and issue tracking systems that are managed by, or on behalf of, the Licensor for the purpose of discussing and improving the Work, but excluding communication that is conspicuously marked or otherwise designated in writing by the copyright owner as "Not a Contribution." "Contributor" shall mean Licensor and any individual or Legal Entity on behalf of whom a Contribution has been received by Licensor and subsequently incorporated within the Work. 2. Grant of Copyright License. Subject to the terms and conditions of this License, each Contributor hereby grants to You a perpetual, worldwide, non-exclusive, no-charge, royalty-free, irrevocable copyright license to reproduce, prepare Derivative Works of, publicly display, publicly perform, sublicense, and distribute the Work and such Derivative Works in Source or Object form. 3. Grant of Patent License. 
Subject to the terms and conditions of this License, each Contributor hereby grants to You a perpetual, worldwide, non-exclusive, no-charge, royalty-free, irrevocable (except as stated in this section) patent license to make, have made, use, offer to sell, sell, import, and otherwise transfer the Work, where such license applies only to those patent claims licensable by such Contributor that are necessarily infringed by their Contribution(s) alone or by combination of their Contribution(s) with the Work to which such Contribution(s) was submitted. If You institute patent litigation against any entity (including a cross-claim or counterclaim in a lawsuit) alleging that the Work or a Contribution incorporated within the Work constitutes direct or contributory patent infringement, then any patent licenses granted to You under this License for that Work shall terminate as of the date such litigation is filed. 4. Redistribution. You may reproduce and distribute copies of the Work or Derivative Works thereof in any medium, with or without modifications, and in Source or Object form, provided that You meet the following conditions: (a) You must give any other recipients of the Work or Derivative Works a copy of this License; and (b) You must cause any modified files to carry prominent notices stating that You changed the files; and (c) You must retain, in the Source form of any Derivative Works that You distribute, all copyright, patent, trademark, and attribution notices from the Source form of the Work, excluding those notices that do not pertain to any part of the Derivative Works; and (d) If the Work includes a "NOTICE" text file as part of its distribution, then any Derivative Works that You distribute must include a readable copy of the attribution notices contained within such NOTICE file, excluding those notices that do not pertain to any part of the Derivative Works, in at least one of the following places: within a NOTICE text file distributed as part of the 
Derivative Works; within the Source form or documentation, if provided along with the Derivative Works; or, within a display generated by the Derivative Works, if and wherever such third-party notices normally appear. The contents of the NOTICE file are for informational purposes only and do not modify the License. You may add Your own attribution notices within Derivative Works that You distribute, alongside or as an addendum to the NOTICE text from the Work, provided that such additional attribution notices cannot be construed as modifying the License. You may add Your own copyright statement to Your modifications and may provide additional or different license terms and conditions for use, reproduction, or distribution of Your modifications, or for any such Derivative Works as a whole, provided Your use, reproduction, and distribution of the Work otherwise complies with the conditions stated in this License. 5. Submission of Contributions. Unless You explicitly state otherwise, any Contribution intentionally submitted for inclusion in the Work by You to the Licensor shall be under the terms and conditions of this License, without any additional terms or conditions. Notwithstanding the above, nothing herein shall supersede or modify the terms of any separate license agreement you may have executed with Licensor regarding such Contributions. 6. Trademarks. This License does not grant permission to use the trade names, trademarks, service marks, or product names of the Licensor, except as required for reasonable and customary use in describing the origin of the Work and reproducing the content of the NOTICE file. 7. Disclaimer of Warranty. 
Unless required by applicable law or agreed to in writing, Licensor provides the Work (and each Contributor provides its Contributions) on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied, including, without limitation, any warranties or conditions of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A PARTICULAR PURPOSE. You are solely responsible for determining the appropriateness of using or redistributing the Work and assume any risks associated with Your exercise of permissions under this License. 8. Limitation of Liability. In no event and under no legal theory, whether in tort (including negligence), contract, or otherwise, unless required by applicable law (such as deliberate and grossly negligent acts) or agreed to in writing, shall any Contributor be liable to You for damages, including any direct, indirect, special, incidental, or consequential damages of any character arising as a result of this License or out of the use or inability to use the Work (including but not limited to damages for loss of goodwill, work stoppage, computer failure or malfunction, or any and all other commercial damages or losses), even if such Contributor has been advised of the possibility of such damages. 9. Accepting Warranty or Additional Liability. While redistributing the Work or Derivative Works thereof, You may choose to offer, and charge a fee for, acceptance of support, warranty, indemnity, or other liability obligations and/or rights consistent with this License. However, in accepting such obligations, You may act only on Your own behalf and on Your sole responsibility, not on behalf of any other Contributor, and only if You agree to indemnify, defend, and hold each Contributor harmless for any liability incurred by, or claims asserted against, such Contributor by reason of your accepting any such warranty or additional liability. END OF TERMS AND CONDITIONS APPENDIX: How to apply the Apache License to your work. 
To apply the Apache License to your work, attach the following boilerplate notice, with the fields enclosed by brackets "[]" replaced with your own identifying information. (Don't include the brackets!) The text should be enclosed in the appropriate comment syntax for the file format. We also recommend that a file or class name and description of purpose be included on the same "printed page" as the copyright notice for easier identification within third-party archives. Copyright [yyyy] [name of copyright owner] Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ================================================ FILE: README.md ================================================ # Neo4j Connector for Apache Spark This repository contains the Neo4j Connector for Apache Spark. ## License This neo4j-connector-apache-spark is Apache 2 Licensed ## Documentation The documentation for Neo4j Connector for Apache Spark lives at https://github.com/neo4j/docs-spark repository. 
## Building for Spark 3 You can build for Spark 3.x with both Scala 2.12 and Scala 2.13 ``` ./maven-release.sh package 2.12 ./maven-release.sh package 2.13 ``` These commands will generate the corresponding targets * `spark-3/target/neo4j-connector-apache-spark_2.12-_for_spark_3.jar` * `spark-3/target/neo4j-connector-apache-spark_2.13-_for_spark_3.jar` ## Integration with Apache Spark Applications **spark-shell, pyspark, or spark-submit** `$SPARK_HOME/bin/spark-shell --jars neo4j-connector-apache-spark_2.12-_for_spark_3.jar` `$SPARK_HOME/bin/spark-shell --packages org.neo4j:neo4j-connector-apache-spark_2.12:_for_spark_3` **sbt** If you use the [sbt-spark-package plugin](https://github.com/databricks/sbt-spark-package), in your sbt build file, add: ```scala resolvers += "Spark Packages Repo" at "http://dl.bintray.com/spark-packages/maven" libraryDependencies += "org.neo4j" % "neo4j-connector-apache-spark_2.12" % "_for_spark_3" ``` **maven** In your pom.xml, add: ```xml org.neo4j neo4j-connector-apache-spark_2.12 [version]_for_spark_3 ``` For more info about the available version visit https://neo4j.com/developer/spark/overview/#_compatibility ================================================ FILE: common/LICENSES.txt ================================================ This file contains the full license text of the included third party libraries. For an overview of the licenses see the NOTICE.txt file. 
------------------------------------------------------------------------------ Apache Software License, Version 2.0 IntelliJ IDEA Annotations Kotlin Stdlib Netty/Buffer Netty/Codec Netty/Common Netty/Handler Netty/Resolver Netty/TomcatNative [OpenSSL - Classes] Netty/Transport Netty/Transport/Native/Unix/Common Non-Blocking Reactive Foundation for the JVM org.apiguardian:apiguardian-api ------------------------------------------------------------------------------ Apache License Version 2.0, January 2004 http://www.apache.org/licenses/ TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 1. Definitions. "License" shall mean the terms and conditions for use, reproduction, and distribution as defined by Sections 1 through 9 of this document. "Licensor" shall mean the copyright owner or entity authorized by the copyright owner that is granting the License. "Legal Entity" shall mean the union of the acting entity and all other entities that control, are controlled by, or are under common control with that entity. For the purposes of this definition, "control" means (i) the power, direct or indirect, to cause the direction or management of such entity, whether by contract or otherwise, or (ii) ownership of fifty percent (50%) or more of the outstanding shares, or (iii) beneficial ownership of such entity. "You" (or "Your") shall mean an individual or Legal Entity exercising permissions granted by this License. "Source" form shall mean the preferred form for making modifications, including but not limited to software source code, documentation source, and configuration files. "Object" form shall mean any form resulting from mechanical transformation or translation of a Source form, including but not limited to compiled object code, generated documentation, and conversions to other media types. 
"Work" shall mean the work of authorship, whether in Source or Object form, made available under the License, as indicated by a copyright notice that is included in or attached to the work (an example is provided in the Appendix below). "Derivative Works" shall mean any work, whether in Source or Object form, that is based on (or derived from) the Work and for which the editorial revisions, annotations, elaborations, or other modifications represent, as a whole, an original work of authorship. For the purposes of this License, Derivative Works shall not include works that remain separable from, or merely link (or bind by name) to the interfaces of, the Work and Derivative Works thereof. "Contribution" shall mean any work of authorship, including the original version of the Work and any modifications or additions to that Work or Derivative Works thereof, that is intentionally submitted to Licensor for inclusion in the Work by the copyright owner or by an individual or Legal Entity authorized to submit on behalf of the copyright owner. For the purposes of this definition, "submitted" means any form of electronic, verbal, or written communication sent to the Licensor or its representatives, including but not limited to communication on electronic mailing lists, source code control systems, and issue tracking systems that are managed by, or on behalf of, the Licensor for the purpose of discussing and improving the Work, but excluding communication that is conspicuously marked or otherwise designated in writing by the copyright owner as "Not a Contribution." "Contributor" shall mean Licensor and any individual or Legal Entity on behalf of whom a Contribution has been received by Licensor and subsequently incorporated within the Work. 2. Grant of Copyright License. 
Subject to the terms and conditions of this License, each Contributor hereby grants to You a perpetual, worldwide, non-exclusive, no-charge, royalty-free, irrevocable copyright license to reproduce, prepare Derivative Works of, publicly display, publicly perform, sublicense, and distribute the Work and such Derivative Works in Source or Object form. 3. Grant of Patent License. Subject to the terms and conditions of this License, each Contributor hereby grants to You a perpetual, worldwide, non-exclusive, no-charge, royalty-free, irrevocable (except as stated in this section) patent license to make, have made, use, offer to sell, sell, import, and otherwise transfer the Work, where such license applies only to those patent claims licensable by such Contributor that are necessarily infringed by their Contribution(s) alone or by combination of their Contribution(s) with the Work to which such Contribution(s) was submitted. If You institute patent litigation against any entity (including a cross-claim or counterclaim in a lawsuit) alleging that the Work or a Contribution incorporated within the Work constitutes direct or contributory patent infringement, then any patent licenses granted to You under this License for that Work shall terminate as of the date such litigation is filed. 4. Redistribution. 
You may reproduce and distribute copies of the Work or Derivative Works thereof in any medium, with or without modifications, and in Source or Object form, provided that You meet the following conditions: (a) You must give any other recipients of the Work or Derivative Works a copy of this License; and (b) You must cause any modified files to carry prominent notices stating that You changed the files; and (c) You must retain, in the Source form of any Derivative Works that You distribute, all copyright, patent, trademark, and attribution notices from the Source form of the Work, excluding those notices that do not pertain to any part of the Derivative Works; and (d) If the Work includes a "NOTICE" text file as part of its distribution, then any Derivative Works that You distribute must include a readable copy of the attribution notices contained within such NOTICE file, excluding those notices that do not pertain to any part of the Derivative Works, in at least one of the following places: within a NOTICE text file distributed as part of the Derivative Works; within the Source form or documentation, if provided along with the Derivative Works; or, within a display generated by the Derivative Works, if and wherever such third-party notices normally appear. The contents of the NOTICE file are for informational purposes only and do not modify the License. You may add Your own attribution notices within Derivative Works that You distribute, alongside or as an addendum to the NOTICE text from the Work, provided that such additional attribution notices cannot be construed as modifying the License. You may add Your own copyright statement to Your modifications and may provide additional or different license terms and conditions for use, reproduction, or distribution of Your modifications, or for any such Derivative Works as a whole, provided Your use, reproduction, and distribution of the Work otherwise complies with the conditions stated in this License. 5. 
Submission of Contributions. Unless You explicitly state otherwise, any Contribution intentionally submitted for inclusion in the Work by You to the Licensor shall be under the terms and conditions of this License, without any additional terms or conditions. Notwithstanding the above, nothing herein shall supersede or modify the terms of any separate license agreement you may have executed with Licensor regarding such Contributions. 6. Trademarks. This License does not grant permission to use the trade names, trademarks, service marks, or product names of the Licensor, except as required for reasonable and customary use in describing the origin of the Work and reproducing the content of the NOTICE file. 7. Disclaimer of Warranty. Unless required by applicable law or agreed to in writing, Licensor provides the Work (and each Contributor provides its Contributions) on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied, including, without limitation, any warranties or conditions of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A PARTICULAR PURPOSE. You are solely responsible for determining the appropriateness of using or redistributing the Work and assume any risks associated with Your exercise of permissions under this License. 8. Limitation of Liability. 
In no event and under no legal theory, whether in tort (including negligence), contract, or otherwise, unless required by applicable law (such as deliberate and grossly negligent acts) or agreed to in writing, shall any Contributor be liable to You for damages, including any direct, indirect, special, incidental, or consequential damages of any character arising as a result of this License or out of the use or inability to use the Work (including but not limited to damages for loss of goodwill, work stoppage, computer failure or malfunction, or any and all other commercial damages or losses), even if such Contributor has been advised of the possibility of such damages. 9. Accepting Warranty or Additional Liability. While redistributing the Work or Derivative Works thereof, You may choose to offer, and charge a fee for, acceptance of support, warranty, indemnity, or other liability obligations and/or rights consistent with this License. However, in accepting such obligations, You may act only on Your own behalf and on Your sole responsibility, not on behalf of any other Contributor, and only if You agree to indemnify, defend, and hold each Contributor harmless for any liability incurred by, or claims asserted against, such Contributor by reason of your accepting any such warranty or additional liability. END OF TERMS AND CONDITIONS APPENDIX: How to apply the Apache License to your work. To apply the Apache License to your work, attach the following boilerplate notice, with the fields enclosed by brackets "[]" replaced with your own identifying information. (Don't include the brackets!) The text should be enclosed in the appropriate comment syntax for the file format. We also recommend that a file or class name and description of purpose be included on the same "printed page" as the copyright notice for easier identification within third-party archives. 
Copyright [yyyy] [name of copyright owner] Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ------------------------------------------------------------------------------ MIT License SLF4J API Module ------------------------------------------------------------------------------ The MIT License Copyright (c) Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions: The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software. THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. 
------------------------------------------------------------------------------ MIT No Attribution License reactive-streams ------------------------------------------------------------------------------ MIT No Attribution Copyright Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so. THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. ================================================ FILE: common/NOTICE.txt ================================================ Copyright (c) "Neo4j" Neo4j Sweden AB [https://neo4j.com] This file is part of Neo4j. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. Full license texts are found in LICENSES.txt. 
Third-party licenses -------------------- Apache Software License, Version 2.0 IntelliJ IDEA Annotations Kotlin Stdlib Netty/Buffer Netty/Codec Netty/Common Netty/Handler Netty/Resolver Netty/TomcatNative [OpenSSL - Classes] Netty/Transport Netty/Transport/Native/Unix/Common Non-Blocking Reactive Foundation for the JVM org.apiguardian:apiguardian-api MIT License SLF4J API Module MIT No Attribution License reactive-streams ================================================ FILE: common/pom.xml ================================================ 4.0.0 org.neo4j neo4j-connector-apache-spark_parent 5.4.3-SNAPSHOT neo4j-connector-apache-spark_common jar neo4j-connector-apache-spark-common Common Services for Neo4j Connector for Apache Spark using the binary Bolt Driver org.neo4j caniuse-core org.neo4j caniuse-neo4j-detection org.neo4j neo4j-cypher-dsl org.neo4j.connectors commons-authn-spi org.neo4j.connectors commons-reauth-driver org.neo4j.driver neo4j-java-driver-slim org.apache.spark spark-core_${scala.binary.version} provided org.apache.spark spark-sql_${scala.binary.version} provided org.scala-lang scala-library provided org.scala-lang scala-reflect provided org.neo4j.connectors commons-authn-provided runtime org.neo4j neo4j-connector-apache-spark_test-support ${project.version} test org.scalatest scalatest_${scala.binary.version} test org.scalatestplus junit-4-13_${scala.binary.version} test pl.pragmatists JUnitParams 1.1.1 test true src/main/resources net.alchim31.maven scala-maven-plugin org.apache.maven.plugins maven-failsafe-plugin org.apache.maven.plugins maven-surefire-plugin ================================================ FILE: common/src/main/java/org/neo4j/spark/util/ReflectionUtils.java ================================================ /* * Copyright (c) "Neo4j" * Neo4j Sweden AB [https://neo4j.com] * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. 
* You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package org.neo4j.spark.util;

import org.apache.spark.sql.connector.expressions.Expression;
import org.apache.spark.sql.connector.expressions.NamedReference;
import org.apache.spark.sql.connector.expressions.aggregate.Aggregation;

import java.lang.invoke.MethodHandle;
import java.lang.invoke.MethodHandles;
import java.lang.invoke.MethodType;
import java.util.Optional;
import java.util.stream.Stream;

/**
 * Reflection-based compatibility shim around Spark's {@link Aggregation}:
 * depending on the Spark version on the classpath, the group-by accessor is
 * either {@code groupByExpressions()} (newer API, returns {@code Expression[]})
 * or {@code groupByColumns()} (older API, returns {@code NamedReference[]}).
 * Both candidates are resolved once, at class-load time, via method handles.
 */
public class ReflectionUtils {

    private static final MethodHandles.Lookup lookup = MethodHandles.lookup();

    /**
     * Resolves the legacy accessor {@code NamedReference[] groupByColumns()}.
     * The handle is adapted with {@code asType} so that both candidate handles
     * share the exact {@code (Aggregation) -> Expression[]} type required by
     * {@code invokeExact}.
     *
     * @return the adapted handle, or empty when this Spark version lacks the method
     */
    private static Optional<MethodHandle> getGroupByColumns() {
        try {
            return Optional.of(lookup
                    .findVirtual(Aggregation.class, "groupByColumns", MethodType.methodType(NamedReference[].class))
                    .asType(MethodType.methodType(Expression[].class, Aggregation.class)));
        } catch (Exception e) {
            // Method not present in this Spark version; caller falls back to the other accessor.
            return Optional.empty();
        }
    }

    /**
     * Resolves the newer accessor {@code Expression[] groupByExpressions()}.
     *
     * @return the handle, or empty when this Spark version lacks the method
     */
    private static Optional<MethodHandle> getGroupByExpressions() {
        try {
            return Optional.of(lookup
                    .findVirtual(Aggregation.class, "groupByExpressions", MethodType.methodType(Expression[].class)));
        } catch (Exception e) {
            return Optional.empty();
        }
    }

    // Resolved once; at most one of the two is typically present on a given Spark version.
    // NOTE: the extraction had stripped the generic parameters here (raw Optional);
    // restored to Optional<MethodHandle> so mh.invokeExact type-checks.
    private static final Optional<MethodHandle> groupByColumns = getGroupByColumns();
    private static final Optional<MethodHandle> groupByExpressions = getGroupByExpressions();

    private static final Expression[] EMPTY = new Expression[0];

    /**
     * Returns the group-by expressions of the given aggregation, preferring the
     * newer {@code groupByExpressions()} accessor and falling back to the legacy
     * {@code groupByColumns()}.
     *
     * @param agg the aggregation pushed down by Spark
     * @return the group-by expressions, or an empty array when neither accessor
     *         is available or invocation fails
     */
    public static Expression[] groupByCols(Aggregation agg) {
        return Stream.of(groupByExpressions, groupByColumns)
                .filter(Optional::isPresent)
                .map(Optional::get)
                .map(mh -> {
                    try {
                        return (Expression[]) mh.invokeExact(agg);
                    } catch (Throwable e) {
                        // invokeExact declares Throwable; treat any failure as "no group-by columns".
                        return EMPTY;
                    }
                })
                .findFirst()
                .orElse(EMPTY);
    }
}
================================================
FILE: common/src/main/resources/neo4j-spark-connector.properties
================================================
version=${project.version}

================================================
FILE: common/src/main/scala/org/neo4j/spark/config/TopN.scala
================================================
/*
 * Copyright (c) "Neo4j"
 * Neo4j Sweden AB [https://neo4j.com]
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package org.neo4j.spark.config

import org.apache.spark.sql.connector.expressions.SortOrder

/**
 * Immutable value pairing a row `limit` with the `orders` (sort specification)
 * it applies to; `orders` defaults to an empty array, i.e. a plain limit with
 * no ordering.
 *
 * NOTE(review): presumably carries Spark's top-N (ORDER BY ... LIMIT n)
 * push-down to the query builder — confirm with callers.
 */
case class TopN(limit: Long, orders: Array[SortOrder] = Array.empty)

================================================
FILE: common/src/main/scala/org/neo4j/spark/converter/DataConverter.scala
================================================
/*
 * Copyright (c) "Neo4j"
 * Neo4j Sweden AB [https://neo4j.com]
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
*/ package org.neo4j.spark.converter import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.GenericRow import org.apache.spark.sql.catalyst.expressions.GenericRowWithSchema import org.apache.spark.sql.catalyst.expressions.UnsafeRow import org.apache.spark.sql.catalyst.util.ArrayBasedMapData import org.apache.spark.sql.catalyst.util.ArrayData import org.apache.spark.sql.catalyst.util.DateTimeUtils import org.apache.spark.sql.catalyst.util.MapData import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.UTF8String import org.neo4j.driver.Value import org.neo4j.driver.Values import org.neo4j.driver.internal._ import org.neo4j.driver.types.IsoDuration import org.neo4j.driver.types.Node import org.neo4j.driver.types.Relationship import org.neo4j.spark.service.SchemaService import org.neo4j.spark.util.Neo4jOptions import org.neo4j.spark.util.Neo4jUtil import java.time._ import java.time.format.DateTimeFormatter import scala.annotation.tailrec import scala.collection.JavaConverters._ trait DataConverter[T] { def convert(value: Any, dataType: DataType = null): T @tailrec final private[converter] def extractStructType(dataType: DataType): StructType = dataType match { case structType: StructType => structType case mapType: MapType => extractStructType(mapType.valueType) case arrayType: ArrayType => extractStructType(arrayType.elementType) case _ => throw new UnsupportedOperationException(s"$dataType not supported") } } object SparkToNeo4jDataConverter { def apply(options: Neo4jOptions): SparkToNeo4jDataConverter = new SparkToNeo4jDataConverter(options) private def dayTimeMicrosToNeo4jDuration(micros: Long): Value = { val oneSecondInMicros = 1000000L val oneDayInMicros = 24 * 3600 * oneSecondInMicros val numberDays = Math.floorDiv(micros, oneDayInMicros) val remainderMicros = Math.floorMod(micros, oneDayInMicros) val numberSeconds = Math.floorDiv(remainderMicros, oneSecondInMicros) val numberNanos = 
Math.floorMod(remainderMicros, oneSecondInMicros) * 1000 Values.isoDuration(0L, numberDays, numberSeconds, numberNanos.toInt) } // while Neo4j supports years, this driver version's API does not expose it. private def yearMonthIntervalToNeo4jDuration(months: Int): Value = { Values.isoDuration(months.toLong, 0L, 0L, 0) } } class SparkToNeo4jDataConverter(options: Neo4jOptions) extends DataConverter[Value] { override def convert(value: Any, dataType: DataType): Value = { value match { case date: java.sql.Date => convert(date.toLocalDate, dataType) case timestamp: java.sql.Timestamp => if (options.legacyTypeConversionEnabled) { convert(timestamp.toLocalDateTime, dataType) } else { convert(timestamp.toInstant.atZone(ZoneOffset.UTC), dataType) } case intValue: Int if dataType == DataTypes.DateType => convert( DateTimeUtils .toJavaDate(intValue), dataType ) case intValue: Int if dataType.isInstanceOf[YearMonthIntervalType] && !options.legacyTypeConversionEnabled => SparkToNeo4jDataConverter.yearMonthIntervalToNeo4jDuration(intValue) case longValue: Long if dataType == DataTypes.TimestampType => convert(DateTimeUtils.toJavaTimestamp(longValue), dataType) case longValue: Long if dataType == DataTypes.TimestampNTZType && !options.legacyTypeConversionEnabled => convert(DateTimeUtils.microsToLocalDateTime(longValue), dataType) case longValue: Long if dataType.isInstanceOf[DayTimeIntervalType] && !options.legacyTypeConversionEnabled => SparkToNeo4jDataConverter.dayTimeMicrosToNeo4jDuration(longValue) case unsafeRow: UnsafeRow => { val structType = extractStructType(dataType) val row = new GenericRowWithSchema(unsafeRow.toSeq(structType).toArray, structType) convert(row) } case struct: GenericRow => { def toMap(struct: GenericRow): Value = { Values.value( struct.schema.fields.map(f => f.name -> convert(struct.getAs(f.name), f.dataType)).toMap.asJava ) } try { struct.getAs[UTF8String]("type").toString match { case SchemaService.POINT_TYPE_2D => Values.point( 
struct.getAs[Number]("srid").intValue(), struct.getAs[Number]("x").doubleValue(), struct.getAs[Number]("y").doubleValue() ) case SchemaService.POINT_TYPE_3D => Values.point( struct.getAs[Number]("srid").intValue(), struct.getAs[Number]("x").doubleValue(), struct.getAs[Number]("y").doubleValue(), struct.getAs[Number]("z").doubleValue() ) case SchemaService.DURATION_TYPE => Values.isoDuration( struct.getAs[Number]("months").longValue(), struct.getAs[Number]("days").longValue(), struct.getAs[Number]("seconds").longValue(), struct.getAs[Number]("nanoseconds").intValue() ) case SchemaService.TIME_TYPE_OFFSET => Values.value(OffsetTime.parse(struct.getAs[UTF8String]("value").toString)) case SchemaService.TIME_TYPE_LOCAL => Values.value(LocalTime.parse(struct.getAs[UTF8String]("value").toString)) case _ => toMap(struct) } } catch { case _: Throwable => toMap(struct) } } case unsafeArray: ArrayData => { val sparkType = dataType match { case arrayType: ArrayType => arrayType.elementType case _ => dataType } if (sparkType == DataTypes.ByteType && !options.legacyTypeConversionEnabled) { Values.value(unsafeArray.toByteArray) } else { val javaList = unsafeArray.toSeq[AnyRef](sparkType) .map(elem => convert(elem, sparkType)) .asJava Values.value(javaList) } } case unsafeMapData: MapData => { // Neo4j only supports Map[String, AnyRef] val mapType = dataType.asInstanceOf[MapType] val map: Map[String, AnyRef] = (0 until unsafeMapData.numElements()) .map(i => (unsafeMapData.keyArray().getUTF8String(i).toString, unsafeMapData.valueArray().get(i, mapType.valueType)) ) .toMap[String, AnyRef] .mapValues(innerValue => convert(innerValue, mapType.valueType)) .toMap[String, AnyRef] Values.value(map.asJava) } case string: UTF8String => convert(string.toString) case decimal: Decimal if dataType.isInstanceOf[DecimalType] => Values.value(decimal.toString) case _ => Values.value(value) } } } object Neo4jToSparkDataConverter { def apply(options: Neo4jOptions): Neo4jToSparkDataConverter = new 
Neo4jToSparkDataConverter(options) } class Neo4jToSparkDataConverter(options: Neo4jOptions) extends DataConverter[Any] { override def convert(value: Any, dataType: DataType): Any = { if (dataType != null && dataType == DataTypes.StringType && value != null && !value.isInstanceOf[String]) { convert(Neo4jUtil.mapper.writeValueAsString(value), dataType) } else { value match { case node: Node => { val map = node.asMap() val structType = extractStructType(dataType) val fields = structType .filter(field => field.name != Neo4jUtil.INTERNAL_ID_FIELD && field.name != Neo4jUtil.INTERNAL_LABELS_FIELD) .map(field => convert(map.get(field.name), field.dataType)) InternalRow.fromSeq(Seq(convert(node.id()), convert(node.labels())) ++ fields) } case rel: Relationship => { val map = rel.asMap() val structType = extractStructType(dataType) val fields = structType .filter(field => field.name != Neo4jUtil.INTERNAL_REL_ID_FIELD && field.name != Neo4jUtil.INTERNAL_REL_TYPE_FIELD && field.name != Neo4jUtil.INTERNAL_REL_SOURCE_ID_FIELD && field.name != Neo4jUtil.INTERNAL_REL_TARGET_ID_FIELD ) .map(field => convert(map.get(field.name), field.dataType)) InternalRow.fromSeq(Seq( convert(rel.id()), convert(rel.`type`()), convert(rel.startNodeId()), convert(rel.endNodeId()) ) ++ fields) } case d: IsoDuration => { val months = d.months() val days = d.days() val nanoseconds: Integer = d.nanoseconds() val seconds = d.seconds() InternalRow.fromSeq(Seq( UTF8String.fromString(SchemaService.DURATION_TYPE), months, days, seconds, nanoseconds, UTF8String.fromString(d.toString) )) } case zt: ZonedDateTime => DateTimeUtils.instantToMicros(zt.toInstant) case dt: LocalDateTime => { if (options.legacyTypeConversionEnabled) { DateTimeUtils.instantToMicros(dt.toInstant(ZoneOffset.UTC)) } else { DateTimeUtils.localDateTimeToMicros(dt) } } case d: LocalDate => d.toEpochDay.toInt case lt: LocalTime => { InternalRow.fromSeq(Seq( UTF8String.fromString(SchemaService.TIME_TYPE_LOCAL), 
UTF8String.fromString(lt.format(DateTimeFormatter.ISO_TIME)) )) } case t: OffsetTime => { InternalRow.fromSeq(Seq( UTF8String.fromString(SchemaService.TIME_TYPE_OFFSET), UTF8String.fromString(t.format(DateTimeFormatter.ISO_TIME)) )) } case p: InternalPoint2D => { val srid: Integer = p.srid() InternalRow.fromSeq(Seq(UTF8String.fromString(SchemaService.POINT_TYPE_2D), srid, p.x(), p.y(), null)) } case p: InternalPoint3D => { val srid: Integer = p.srid() InternalRow.fromSeq(Seq(UTF8String.fromString(SchemaService.POINT_TYPE_3D), srid, p.x(), p.y(), p.z())) } case l: java.util.List[_] => { val elementType = if (dataType != null) dataType.asInstanceOf[ArrayType].elementType else null ArrayData.toArrayData(l.asScala.map(e => convert(e, elementType)).toArray) } case map: java.util.Map[_, _] => { if (dataType != null) { val mapType = dataType.asInstanceOf[MapType] ArrayBasedMapData(map.asScala.map(t => (convert(t._1, mapType.keyType), convert(t._2, mapType.valueType)))) } else { ArrayBasedMapData(map.asScala.map(t => (convert(t._1), convert(t._2)))) } } case s: String => UTF8String.fromString(s) case _ => value } } } } ================================================ FILE: common/src/main/scala/org/neo4j/spark/converter/TypeConverter.scala ================================================ /* * Copyright (c) "Neo4j" * Neo4j Sweden AB [https://neo4j.com] * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. 
*/ package org.neo4j.spark.converter import org.apache.spark.sql.types.DataType import org.apache.spark.sql.types.DataTypes import org.apache.spark.sql.types.DayTimeIntervalType import org.apache.spark.sql.types.DecimalType import org.apache.spark.sql.types.YearMonthIntervalType import org.neo4j.driver.types.Entity import org.neo4j.spark.converter.CypherToSparkTypeConverter.cleanTerms import org.neo4j.spark.converter.CypherToSparkTypeConverter.durationType import org.neo4j.spark.converter.CypherToSparkTypeConverter.pointType import org.neo4j.spark.converter.CypherToSparkTypeConverter.timeType import org.neo4j.spark.converter.SparkToCypherTypeConverter.mapping import org.neo4j.spark.service.SchemaService.normalizedClassName import org.neo4j.spark.util.Neo4jImplicits.EntityImplicits import org.neo4j.spark.util.Neo4jOptions import scala.collection.JavaConverters._ trait TypeConverter[SOURCE_TYPE, DESTINATION_TYPE] { def convert(sourceType: SOURCE_TYPE, value: Any = null): DESTINATION_TYPE } object CypherToSparkTypeConverter { def apply(options: Neo4jOptions): CypherToSparkTypeConverter = new CypherToSparkTypeConverter(options) private val cleanTerms: String = "Unmodifiable|Internal|Iso|2D|3D|Offset" val durationType: DataType = DataTypes.createStructType(Array( DataTypes.createStructField("type", DataTypes.StringType, false), DataTypes.createStructField("months", DataTypes.LongType, false), DataTypes.createStructField("days", DataTypes.LongType, false), DataTypes.createStructField("seconds", DataTypes.LongType, false), DataTypes.createStructField("nanoseconds", DataTypes.IntegerType, false), DataTypes.createStructField("value", DataTypes.StringType, false) )) val pointType: DataType = DataTypes.createStructType(Array( DataTypes.createStructField("type", DataTypes.StringType, false), DataTypes.createStructField("srid", DataTypes.IntegerType, false), DataTypes.createStructField("x", DataTypes.DoubleType, false), DataTypes.createStructField("y", DataTypes.DoubleType, 
false), DataTypes.createStructField("z", DataTypes.DoubleType, true) )) val timeType: DataType = DataTypes.createStructType(Array( DataTypes.createStructField("type", DataTypes.StringType, false), DataTypes.createStructField("value", DataTypes.StringType, false) )) } class CypherToSparkTypeConverter(options: Neo4jOptions) extends TypeConverter[String, DataType] { override def convert(sourceType: String, value: Any = null): DataType = { var cleanedSourceType = sourceType.replaceAll(cleanTerms, "") if (options.legacyTypeConversionEnabled) { cleanedSourceType = cleanedSourceType.replaceAll("Local|Zoned", "") } cleanedSourceType match { case "Node" | "Relationship" => if (value != null) value.asInstanceOf[Entity].toStruct(options) else DataTypes.NullType case "NodeArray" | "RelationshipArray" => if (value != null) DataTypes.createArrayType(value.asInstanceOf[Entity].toStruct(options)) else DataTypes.NullType case "Boolean" => DataTypes.BooleanType case "Long" => DataTypes.LongType case "Double" => DataTypes.DoubleType case "Point" => pointType case "DateTime" | "ZonedDateTime" => DataTypes.TimestampType case "LocalDateTime" => if (options.legacyTypeConversionEnabled) { DataTypes.TimestampType } else { DataTypes.TimestampNTZType } case "Time" | "LocalTime" => timeType case "Date" | "LocalDate" => DataTypes.DateType case "Duration" => durationType case "ByteArray" => DataTypes.BinaryType case "Map" => { val valueType = if (value == null) { DataTypes.NullType } else { val map = value.asInstanceOf[java.util.Map[String, AnyRef]].asScala val types = map.values .map(value => normalizedClassName(value, options)) .toSet if (types.size == 1) convert(types.head, map.values.head) else DataTypes.StringType } DataTypes.createMapType(DataTypes.StringType, valueType) } case "Array" => { val valueType = if (value == null) { DataTypes.NullType } else { val list = value.asInstanceOf[java.util.List[AnyRef]].asScala val types = list .map(value => normalizedClassName(value, options)) .toSet 
if (types.size == 1) convert(types.head, list.head) else DataTypes.StringType } DataTypes.createArrayType(valueType) } // These are from APOC case "StringArray" => DataTypes.createArrayType(DataTypes.StringType) case "LongArray" => DataTypes.createArrayType(DataTypes.LongType) case "DoubleArray" => DataTypes.createArrayType(DataTypes.DoubleType) case "BooleanArray" => DataTypes.createArrayType(DataTypes.BooleanType) case "PointArray" => DataTypes.createArrayType(pointType) case "DateTimeArray" | "ZonedDateTimeArray" => DataTypes.createArrayType(DataTypes.TimestampType) case "TimeArray" | "LocalTimeArray" => DataTypes.createArrayType(timeType) case "DateArray" | "LocalDateArray" => DataTypes.createArrayType(DataTypes.DateType) case "DurationArray" => DataTypes.createArrayType(durationType) // Default is String case _ => DataTypes.StringType } } } object SparkToCypherTypeConverter { def apply(options: Neo4jOptions): SparkToCypherTypeConverter = new SparkToCypherTypeConverter(options) private val baseMappings: Map[DataType, String] = Map( DataTypes.BooleanType -> "BOOLEAN", DataTypes.StringType -> "STRING", DecimalType.SYSTEM_DEFAULT -> "STRING", DataTypes.ByteType -> "INTEGER", DataTypes.ShortType -> "INTEGER", DataTypes.IntegerType -> "INTEGER", DataTypes.LongType -> "INTEGER", DataTypes.FloatType -> "FLOAT", DataTypes.DoubleType -> "FLOAT", DataTypes.DateType -> "DATE", durationType -> "DURATION", pointType -> "POINT", // Cypher graph entities do not allow null values in arrays DataTypes.createArrayType(DataTypes.BooleanType, false) -> "LIST", DataTypes.createArrayType(DataTypes.StringType, false) -> "LIST", DataTypes.createArrayType(DecimalType.SYSTEM_DEFAULT, false) -> "LIST", DataTypes.createArrayType(DataTypes.ShortType, false) -> "LIST", DataTypes.createArrayType(DataTypes.IntegerType, false) -> "LIST", DataTypes.createArrayType(DataTypes.LongType, false) -> "LIST", DataTypes.createArrayType(DataTypes.FloatType, false) -> "LIST", 
DataTypes.createArrayType(DataTypes.DoubleType, false) -> "LIST", DataTypes.createArrayType(DataTypes.DateType, false) -> "LIST", DataTypes.createArrayType(durationType, false) -> "LIST", DataTypes.createArrayType(pointType, false) -> "LIST" ) private def mapping(sourceType: DataType, options: Neo4jOptions): String = { val mappings = sourceTypeMappings(options) mappings(sourceType) } private def sourceTypeMappings(options: Neo4jOptions): Map[DataType, String] = { var result = baseMappings if (options.legacyTypeConversionEnabled) { result += (DataTypes.TimestampType -> "LOCAL DATETIME") result += (DataTypes.createArrayType(DataTypes.TimestampType, false) -> "LIST") result += (DataTypes.createArrayType(DataTypes.TimestampType, true) -> "LIST") } else { result += (DataTypes.TimestampType -> "ZONED DATETIME") result += (DataTypes.TimestampNTZType -> "LOCAL DATETIME") result += (DayTimeIntervalType() -> "DURATION") result += (YearMonthIntervalType() -> "DURATION") result += (DataTypes.createArrayType(DataTypes.ByteType, false) -> "ByteArray") result += (DataTypes.createArrayType(DataTypes.TimestampType, false) -> "LIST") result += (DataTypes.createArrayType(DataTypes.TimestampNTZType, false) -> "LIST") result += (DataTypes.createArrayType(DayTimeIntervalType(), false) -> "LIST") result += (DataTypes.createArrayType(DayTimeIntervalType(), true) -> "LIST") result += (DataTypes.createArrayType(YearMonthIntervalType(), false) -> "LIST") result += (DataTypes.createArrayType(YearMonthIntervalType(), true) -> "LIST") } result } } class SparkToCypherTypeConverter(options: Neo4jOptions) extends TypeConverter[DataType, String] { override def convert(sourceType: DataType, value: Any): String = mapping(sourceType, options) } ================================================ FILE: common/src/main/scala/org/neo4j/spark/cypher/Cypher5Renderer.scala ================================================ /* * Copyright (c) "Neo4j" * Neo4j Sweden AB [https://neo4j.com] * * Licensed under the 
Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package org.neo4j.spark.cypher

import org.neo4j.caniuse.Neo4j
import org.neo4j.caniuse.Neo4jVersion
import org.neo4j.cypherdsl.core.Statement
import org.neo4j.cypherdsl.core.renderer.Configuration
import org.neo4j.cypherdsl.core.renderer.Dialect
import org.neo4j.cypherdsl.core.renderer.Renderer

import org.neo4j.spark.cypher.Cypher5Renderer.Neo4jV5
import org.neo4j.spark.cypher.CypherVersionSelector.selectCypherVersionClause

/**
 * Cypher-DSL [[Renderer]] that adapts its output to the connected Neo4j server:
 * statements are rendered with the `NEO4J_5` dialect for servers at version
 * 5.0.0 or later (the `DEFAULT` dialect otherwise), and the rendered text is
 * prefixed with an explicit Cypher version clause when one is required
 * (see [[CypherVersionSelector.selectCypherVersionClause]]).
 */
class Cypher5Renderer(neo4j: Neo4j) extends Renderer {

  // The dialect is decided once per renderer instance from the server version.
  private val delegate = Renderer.getRenderer(
    Configuration.newConfig()
      .withDialect(
        if (neo4j.getVersion.compareTo(Neo4jV5) < 0) {
          Dialect.DEFAULT
        } else {
          Dialect.NEO4J_5
        }
      )
      .build()
  )

  /** Renders `statement` via the delegate and prepends the Cypher version
   *  clause (which is the empty string when no explicit selection is needed). */
  override def render(statement: Statement): String = {
    val rendered = delegate.render(statement)
    s"${selectCypherVersionClause(neo4j)}$rendered"
  }
}

private object Cypher5Renderer {

  // Version threshold at which the NEO4J_5 rendering dialect applies.
  private val Neo4jV5 = new Neo4jVersion(5, 0, 0)
}

================================================
FILE: common/src/main/scala/org/neo4j/spark/cypher/CypherVersionSelector.scala
================================================
/*
 * Copyright (c) "Neo4j"
 * Neo4j Sweden AB [https://neo4j.com]
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package org.neo4j.spark.cypher

import org.neo4j.caniuse.CanIUse.INSTANCE.canIUse
import org.neo4j.caniuse.Cypher.{INSTANCE => Cypher}
import org.neo4j.caniuse.Neo4j

object CypherVersionSelector {

  /**
   * Returns the query prefix that explicitly selects Cypher 5 ("CYPHER 5 ")
   * when the target server supports explicit Cypher 5 selection, otherwise
   * the empty string.
   */
  def selectCypherVersionClause(neo4j: Neo4j): String =
    if (canIUse(Cypher.explicitCypher5Selection()).withNeo4j(neo4j)) "CYPHER 5 "
    else ""
}

================================================
FILE: common/src/main/scala/org/neo4j/spark/reader/BasePartitionReader.scala
================================================
/*
 * Copyright (c) "Neo4j"
 * Neo4j Sweden AB [https://neo4j.com]
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 */
package org.neo4j.spark.reader

import org.apache.spark.internal.Logging
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.connector.expressions.aggregate.AggregateFunc
import org.apache.spark.sql.sources.Filter
import org.apache.spark.sql.types.StructType
import org.neo4j.caniuse.Neo4j
import org.neo4j.driver.Record
import org.neo4j.driver.Session
import org.neo4j.driver.Transaction
import org.neo4j.driver.Values
import org.neo4j.spark.service.MappingService
import org.neo4j.spark.service.Neo4jQueryReadStrategy
import org.neo4j.spark.service.Neo4jQueryService
import org.neo4j.spark.service.Neo4jQueryStrategy
import org.neo4j.spark.service.Neo4jReadMappingStrategy
import org.neo4j.spark.service.PartitionPagination
import org.neo4j.spark.util.DriverCache
import org.neo4j.spark.util.Neo4jOptions
import org.neo4j.spark.util.Neo4jUtil
import org.neo4j.spark.util.QueryType

import java.io.IOException
import java.time.Duration
import java.util
import java.util.concurrent.atomic.AtomicInteger
import java.util.concurrent.locks.LockSupport

import scala.collection.JavaConverters._

/**
 * Base reader for a single Spark partition.
 *
 * On the first call to [[next]] it lazily opens a driver session and
 * transaction, runs the Cypher generated by [[Neo4jQueryService]] (with
 * [[Neo4jQueryReadStrategy]]), and then streams driver [[Record]]s, converting
 * each one to an [[InternalRow]] via [[MappingService]]. Retryable driver
 * failures are retried up to `options.transactionSettings.retries` times with a
 * fixed pause of `options.transactionSettings.retryTimeout` ms between attempts.
 */
abstract class BasePartitionReader(
  private val neo4j: Neo4j,
  private val options: Neo4jOptions,
  private val filters: Array[Filter],
  private val schema: StructType,
  private val jobId: String,
  private val partitionSkipLimit: PartitionPagination,
  private val scriptResult: java.util.List[java.util.Map[String, AnyRef]],
  private val requiredColumns: StructType,
  private val aggregateColumns: Array[AggregateFunc]
) extends Logging {

  // `result == null` is the sentinel meaning "query not (re-)started yet";
  // nextHandler() opens the session/transaction when it sees it.
  private var result: Iterator[Record] = _
  private var session: Session = _
  private var transaction: Transaction = _

  // Logical reader name: jobId, suffixed with the partition number for
  // partitions after the first.
  protected val name: String = if (partitionSkipLimit.partitionNumber > 0)
    s"$jobId-${partitionSkipLimit.partitionNumber}"
  else jobId

  protected val driverCache: DriverCache = new DriverCache(options.connection)

  private var nextRow: InternalRow = _

  // Query parameters, built once: the script result plus any parameters
  // derived from pushed-down filters; GDS queries also get the procedure
  // parameters from the options.
  private lazy val values = {
    val params = new java.util.HashMap[String, Any]()
    params.put(Neo4jQueryStrategy.VARIABLE_SCRIPT_RESULT, scriptResult)
    Neo4jUtil.paramsFromFilters(filters)
      .foreach(p => params.put(p._1, p._2))

    if (options.query.queryType == QueryType.GDS) {
      params.putAll(options.gdsMetadata.parameters)
    }

    params
  }

  private val mappingService = new MappingService(new Neo4jReadMappingStrategy(options, requiredColumns), options)

  // Set once a non-retryable (or retry-exhausted) failure occurs; exposed via hasError().
  @volatile
  private var error: Boolean = false

  private val retries = new AtomicInteger(options.transactionSettings.retries)

  /**
   * Advances to the next record, returning true if one is available.
   *
   * Failures matching the configured failure condition are rethrown
   * immediately as [[IOException]]. Retryable failures trigger a close,
   * a pause, and a fresh query; anything else is rethrown as [[IOException]].
   */
  @throws(classOf[IOException])
  def next: Boolean =
    try {
      nextHandler()
    } catch {
      case t: Throwable =>
        if (options.transactionSettings.shouldFailOn(t)) {
          error = true
          logError("Error while invoking next due to explicitly configured failure condition:", t)
          throw new IOException(t)
        }
        if (Neo4jUtil.isRetryableException(t) && retries.get() > 0) {
          val currentRetry = retries.decrementAndGet
          logInfo(
            s"encountered a transient exception while reading, retrying ${options.transactionSettings.retries - currentRetry} time(s)",
            t
          )
          // NOTE(review): close() also closes the DriverCache; the retry below
          // re-opens it via getOrCreate() — confirm this teardown is intended.
          close()
          result = null // Reset result to force new query
          // Wait before retry
          LockSupport.parkNanos(Duration.ofMillis(options.transactionSettings.retryTimeout).toNanos)
          next
        } else {
          error = true
          logError("Error while invoking next:", t)
          throw new IOException(t)
        }
    }

  // Opens session/transaction and runs the query on first use, then caches
  // the converted row for get.
  private def nextHandler(): Boolean = {
    if (result == null) {
      session = driverCache.getOrCreate().session(options.session.toNeo4jSession())
      transaction = session.beginTransaction(options.toNeo4jTransactionConfig)
      val queryText = query()
      val queryParams = queryParameters
      logInfo(s"Running the following query on Neo4j: $queryText")
      logDebug(s"with parameters $queryParams")
      result = transaction.run(queryText, Values.value(queryParams))
        .asScala
    }
    if (result.hasNext) {
      nextRow = mappingService.convert(result.next(), schema)
      true
    } else {
      false
    }
  }

  /** Returns the row produced by the last successful [[next]]. */
  def get: InternalRow = nextRow

  /** Closes transaction, session and the driver cache, swallowing close errors. */
  def close(): Unit = {
    Neo4jUtil.closeSafely(transaction, log)
    Neo4jUtil.closeSafely(session, log)
    driverCache.close()
  }

  /** True once a non-retryable failure has been observed. */
  def hasError(): Boolean =
    error

  // Builds the partition's Cypher query from the read strategy.
  protected def query(): String = {
    new Neo4jQueryService(
      options,
      new Neo4jQueryReadStrategy(
        neo4j,
        filters,
        partitionSkipLimit,
        requiredColumns.fieldNames,
        aggregateColumns,
        jobId
      )
    )
      .createQuery()
  }

  protected def queryParameters: util.Map[String, Any] = values
}


================================================
FILE: common/src/main/scala/org/neo4j/spark/service/MappingService.scala
================================================
/*
 * Copyright (c) "Neo4j"
 * Neo4j Sweden AB [https://neo4j.com]
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package org.neo4j.spark.service

import org.apache.spark.internal.Logging
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.types.StructType
import org.neo4j.driver.Record
import org.neo4j.driver.Value
import org.neo4j.driver.Values
import org.neo4j.driver.internal.value.MapValue
import org.neo4j.driver.types.Node
import org.neo4j.spark.converter.Neo4jToSparkDataConverter
import org.neo4j.spark.converter.SparkToNeo4jDataConverter
import org.neo4j.spark.service.Neo4jWriteMappingStrategy.KEYS
import org.neo4j.spark.service.Neo4jWriteMappingStrategy.PROPERTIES
import org.neo4j.spark.util.Neo4jImplicits._
import org.neo4j.spark.util.Neo4jNodeMetadata
import org.neo4j.spark.util.Neo4jOptions
import org.neo4j.spark.util.Neo4jUtil
import org.neo4j.spark.util.QueryType
import org.neo4j.spark.util.RelationshipSaveStrategy

import java.util
import java.util.function
import java.util.function.BiConsumer

import scala.collection.JavaConverters._
import scala.collection.mutable

/**
 * Write-side mapping: converts a Spark [[InternalRow]] into the nested
 * `keys`/`properties` Java-map event structure consumed by the generated
 * write queries. Returns `None` when a row must be skipped because a
 * configured key column is null (`skipNullKeys`).
 */
class Neo4jWriteMappingStrategy(private val options: Neo4jOptions)
    extends Neo4jMappingStrategy[InternalRow, Option[java.util.Map[String, AnyRef]]]
    with Logging {

  private val dataConverter = SparkToNeo4jDataConverter(options)

  /**
   * Maps a row to `{keys: {...}, properties: {...}}` for a node write;
   * columns listed in `nodeMetadata.nodeKeys` go to `keys` (renamed via the
   * mapping), everything else to `properties`.
   */
  override def node(row: InternalRow, schema: StructType): Option[java.util.Map[String, AnyRef]] = {
    val rowMap: java.util.Map[String, Object] = new java.util.HashMap[String, Object]
    val keys: java.util.Map[String, Object] = new java.util.HashMap[String, Object]
    val properties: java.util.Map[String, Object] = new java.util.HashMap[String, Object]
    rowMap.put(KEYS, keys)
    rowMap.put(PROPERTIES, properties)

    query(row, schema)
      .get
      .forEach(new BiConsumer[String, AnyRef] {
        override def accept(key: String, value: AnyRef): Unit =
          if (options.nodeMetadata.nodeKeys.contains(key)) {
            keys.put(options.nodeMetadata.nodeKeys.getOrElse(key, key), value)
          } else {
            properties.put(options.nodeMetadata.properties.getOrElse(key, key), value)
          }
      })

    if (options.nodeMetadata.skipNullKeys && containsNull(keys)) {
      logSkipping("node keys", options.nodeMetadata.nodeKeys.values)
      None
    } else {
      Some(rowMap)
    }
  }

  // NATIVE strategy: columns are expected to be prefixed with "rel.",
  // "source." or "target."; the prefix picks the bucket and is stripped.
  private def nativeStrategyConsumer(): MappingBiConsumer = new MappingBiConsumer {

    override def accept(key: String, value: AnyRef): Unit = {
      if (key.startsWith(Neo4jUtil.RELATIONSHIP_ALIAS.concat("."))) {
        relMap.get(PROPERTIES).put(key.removeAlias(), value)
      } else if (key.startsWith(Neo4jUtil.RELATIONSHIP_SOURCE_ALIAS.concat("."))) {
        if (options.relationshipMetadata.source.nodeKeys.contains(key)) {
          sourceNodeMap.get(KEYS).put(key.removeAlias(), value)
        } else {
          sourceNodeMap.get(PROPERTIES).put(key.removeAlias(), value)
        }
      } else if (key.startsWith(Neo4jUtil.RELATIONSHIP_TARGET_ALIAS.concat("."))) {
        if (options.relationshipMetadata.target.nodeKeys.contains(key)) {
          targetNodeMap.get(KEYS).put(key.removeAlias(), value)
        } else {
          targetNodeMap.get(PROPERTIES).put(key.removeAlias(), value)
        }
      }
    }
  }

  // Routes a column into a node bucket: to KEYS if it is a configured node
  // key, and to PROPERTIES if it is a configured node property (both renamed
  // via their mapping). A column can land in both buckets.
  private def addToNodeMap(
    nodeMap: util.Map[String, util.Map[String, AnyRef]],
    nodeMetadata: Neo4jNodeMetadata,
    key: String,
    value: AnyRef
  ): Unit = {
    if (nodeMetadata.nodeKeys.contains(key)) {
      nodeMap.get(KEYS).put(nodeMetadata.nodeKeys.getOrElse(key, key), value)
    }
    if (nodeMetadata.properties.contains(key)) {
      nodeMap.get(PROPERTIES).put(nodeMetadata.properties.getOrElse(key, key), value)
    }
  }

  // KEYS strategy: columns are routed by the configured source/target/rel
  // key and property mappings; unmapped columns default to relationship
  // properties unless claimed by source or target.
  private def keysStrategyConsumer(): MappingBiConsumer = new MappingBiConsumer {

    override def accept(key: String, value: AnyRef): Unit = {
      val source = options.relationshipMetadata.source
      val target = options.relationshipMetadata.target
      addToNodeMap(sourceNodeMap, source, key, value)
      addToNodeMap(targetNodeMap, target, key, value)

      if (options.relationshipMetadata.relationshipKeys.contains(key)) {
        relMap.get(KEYS).put(options.relationshipMetadata.relationshipKeys.getOrElse(key, key), value)
      } else {
        val propertyKey = options.relationshipMetadata.properties match {
          case Some(relProperties) => relProperties.get(key)
          case None =>
            if (!source.includesProperty(key) && !target.includesProperty(key)) {
              Some(key)
            } else {
              None
            }
        }
        propertyKey.foreach(k => relMap.get(PROPERTIES).put(k, value))
      }
    }
  }

  /**
   * Maps a row to the `{rel: ..., source: ..., target: ...}` event structure
   * for a relationship write, using the NATIVE or KEYS consumer depending on
   * the configured save strategy.
   *
   * @throws IllegalArgumentException for NATIVE when no rel/source/target
   *                                  column produced any value
   */
  override def relationship(row: InternalRow, schema: StructType): Option[java.util.Map[String, AnyRef]] = {
    val rowMap: java.util.Map[String, AnyRef] = new java.util.HashMap[String, AnyRef]
    val consumer = options.relationshipMetadata.saveStrategy match {
      case RelationshipSaveStrategy.NATIVE => nativeStrategyConsumer()
      case RelationshipSaveStrategy.KEYS   => keysStrategyConsumer()
    }
    query(row, schema).get.forEach(consumer)

    if (
      options.relationshipMetadata.saveStrategy.equals(RelationshipSaveStrategy.NATIVE)
      && consumer.relMap.get(PROPERTIES).isEmpty && consumer.sourceNodeMap.get(PROPERTIES).isEmpty
      && consumer.sourceNodeMap.get(KEYS).isEmpty && consumer.targetNodeMap.get(PROPERTIES).isEmpty
      && consumer.targetNodeMap.get(KEYS).isEmpty
    ) {
      throw new IllegalArgumentException(
        "NATIVE write strategy requires a schema like: rel.[props], source.[props], target.[props]. " +
          "All of these columns are empty in the current schema."
      )
    }

    if (options.relationshipMetadata.skipNullKeys && containsNull(consumer.relMap, KEYS)) {
      logSkipping("relationship keys", options.relationshipMetadata.relationshipKeys.values)
      None
    } else if (options.relationshipMetadata.source.skipNullKeys && containsNull(consumer.sourceNodeMap, KEYS)) {
      logSkipping("source node keys", options.relationshipMetadata.source.nodeKeys.values)
      None
    } else if (options.relationshipMetadata.target.skipNullKeys && containsNull(consumer.targetNodeMap, KEYS)) {
      logSkipping("target node keys", options.relationshipMetadata.target.nodeKeys.values)
      None
    } else {
      rowMap.put(Neo4jUtil.RELATIONSHIP_ALIAS, consumer.relMap)
      rowMap.put(Neo4jUtil.RELATIONSHIP_SOURCE_ALIAS, consumer.sourceNodeMap)
      rowMap.put(Neo4jUtil.RELATIONSHIP_TARGET_ALIAS, consumer.targetNodeMap)
      Some(rowMap)
    }
  }

  /**
   * Converts a row into a flat column-name -> Neo4j-value map. Map-typed
   * columns are flattened into dotted keys via `flattenMap`.
   */
  override def query(row: InternalRow, schema: StructType): Option[java.util.Map[String, AnyRef]] = {
    val seq = row.toSeq(schema)
    Some(
      schema.indices
        .flatMap(i => {
          val field = schema(i)
          val neo4jValue = dataConverter.convert(seq(i), field.dataType)
          neo4jValue match {
            case map: MapValue =>
              map.asMap().asScala.toMap
                .flattenMap(field.name, options.schemaMetadata.mapGroupDuplicateKeys)
                .mapValues(value => Values.value(value).asInstanceOf[AnyRef])
                .toSeq
            case _ => Seq((field.name, neo4jValue))
          }
        })
        .toMap
        .asJava
    )
  }

  // Helper methods
  private def containsNull(map: java.util.Map[String, Object]): Boolean = {
    map.containsValue(Values.NULL)
  }

  private def containsNull(map: java.util.Map[String, java.util.Map[String, AnyRef]], key: String): Boolean = {
    map.get(key).containsValue(Values.NULL)
  }

  private def logSkipping(keyType: String, keys: Iterable[String]): Unit = {
    logTrace(s"Skipping row because it contains null value for one of the $keyType: [${keys.mkString(", ")}]")
  }
}

/**
 * Read-side mapping: converts a driver [[Record]] into a Spark
 * [[InternalRow]] following the requested schema. When `requiredColumns` is
 * non-empty the record is already projected and handled as a plain query
 * result.
 */
class Neo4jReadMappingStrategy(private val options: Neo4jOptions, requiredColumns: StructType)
    extends Neo4jMappingStrategy[Record, InternalRow] {

  private val dataConverter =
Neo4jToSparkDataConverter(options)

  // Maps a node record: its properties plus the synthetic <id>/<labels>
  // columns, unless a column projection is active.
  override def node(record: Record, schema: StructType): InternalRow = {
    if (requiredColumns.nonEmpty) {
      query(record, schema)
    } else {
      val node = record.get(Neo4jUtil.NODE_ALIAS).asNode()
      val nodeMap = new util.HashMap[String, Any](node.asMap())
      nodeMap.put(Neo4jUtil.INTERNAL_ID_FIELD, node.id())
      nodeMap.put(Neo4jUtil.INTERNAL_LABELS_FIELD, node.labels())
      mapToInternalRow(nodeMap, schema)
    }
  }

  // Builds an InternalRow by looking up and converting each schema field
  // from the given map (missing keys convert from null).
  private def mapToInternalRow(map: util.Map[String, Any], schema: StructType) =
    InternalRow
      .fromSeq(
        schema.map(field => dataConverter.convert(map.get(field.name), field.dataType))
      )

  // Flattens a source/target node into "<alias>.<prop>" entries plus the
  // synthetic "<alias.id>" / "<alias.labels>" columns.
  private def flatRelNodeMapping(node: Node, alias: String): mutable.Map[String, Any] = {
    val nodeMap: mutable.Map[String, Any] = node.asMap().asScala
      .map(t => (s"$alias.${t._1}", t._2))
    nodeMap.put(
      s"<$alias.${
          Neo4jUtil.INTERNAL_ID_FIELD
            .replaceAll("[<|>]", "")
        }>",
      node.id()
    )
    nodeMap.put(
      s"<$alias.${
          Neo4jUtil.INTERNAL_LABELS_FIELD
            .replaceAll("[<|>]", "")
        }>",
      node.labels()
    )

    nodeMap
  }

  // Collapses a source/target node into a single "<alias>" column holding a
  // string map of its properties (id/labels JSON-serialized via the mapper).
  private def mapRelNodeMapping(node: Node, alias: String): Map[String, util.Map[String, String]] = {
    val nodeMap: util.Map[String, String] =
      new util.HashMap[String, String](node.asMap(new function.Function[Value, String] {
        override def apply(t: Value): String = t.toString
      }))
    nodeMap.put(Neo4jUtil.INTERNAL_ID_FIELD, Neo4jUtil.mapper.writeValueAsString(node.id()))
    nodeMap.put(Neo4jUtil.INTERNAL_LABELS_FIELD, Neo4jUtil.mapper.writeValueAsString(node.labels()))

    Map(s"<$alias>" -> nodeMap)
  }

  // Maps a relationship record: rel properties (prefixed "rel.") plus the
  // synthetic rel id/type columns and the source/target nodes, either
  // flattened or as nested maps depending on relationshipMetadata.nodeMap.
  override def relationship(record: Record, schema: StructType): InternalRow = {
    if (requiredColumns.nonEmpty) {
      query(record, schema)
    } else {
      val rel = record.get(Neo4jUtil.RELATIONSHIP_ALIAS).asRelationship()
      val relMap = new util.HashMap[String, Any](rel.asMap())
        .asScala
        .map(t => (s"rel.${t._1}", t._2))
        .asJava
      relMap.put(Neo4jUtil.INTERNAL_REL_ID_FIELD, rel.id())
      relMap.put(Neo4jUtil.INTERNAL_REL_TYPE_FIELD, rel.`type`())

      val source = record.get(Neo4jUtil.RELATIONSHIP_SOURCE_ALIAS).asNode()
      val target = record.get(Neo4jUtil.RELATIONSHIP_TARGET_ALIAS).asNode()

      val (sourceMap, targetMap) = if (options.relationshipMetadata.nodeMap) {
        (
          mapRelNodeMapping(source, Neo4jUtil.RELATIONSHIP_SOURCE_ALIAS),
          mapRelNodeMapping(target, Neo4jUtil.RELATIONSHIP_TARGET_ALIAS)
        )
      } else {
        (
          flatRelNodeMapping(source, Neo4jUtil.RELATIONSHIP_SOURCE_ALIAS),
          flatRelNodeMapping(target, Neo4jUtil.RELATIONSHIP_TARGET_ALIAS)
        )
      }

      relMap.putAll(sourceMap.toMap.asJava)
      relMap.putAll(targetMap.toMap.asJava)

      mapToInternalRow(relMap, schema)
    }
  }

  // Generic query result: every record entry converted to a plain object.
  override def query(elem: Record, schema: StructType): InternalRow =
    mapToInternalRow(
      elem.asMap(new function.Function[Value, Any] {
        override def apply(t: Value): Any = t.asObject()
      }),
      schema
    )
}

/**
 * Strategy contract for converting between a driver/Spark element and the
 * target representation, per query type.
 */
abstract class Neo4jMappingStrategy[IN, OUT] extends Serializable {
  def node(elem: IN, schema: StructType): OUT
  def relationship(elem: IN, schema: StructType): OUT
  def query(elem: IN, schema: StructType): OUT
}

/**
 * Dispatches a conversion to the strategy method matching the configured
 * query type (GDS results are treated as plain query results).
 */
class MappingService[IN, OUT](private val strategy: Neo4jMappingStrategy[IN, OUT], private val options: Neo4jOptions)
    extends Serializable {

  def convert(record: IN, schema: StructType): OUT = options.query.queryType match {
    case QueryType.LABELS       => strategy.node(record, schema)
    case QueryType.RELATIONSHIP => strategy.relationship(record, schema)
    case QueryType.QUERY        => strategy.query(record, schema)
    case QueryType.GDS          => strategy.query(record, schema)
  }
}

object Neo4jWriteMappingStrategy {
  // Bucket names used inside the write event maps.
  val KEYS = "keys"
  val PROPERTIES = "properties"
}

/**
 * Accumulating consumer used by the relationship write strategies: each of
 * rel/source/target gets a pre-initialized {keys, properties} bucket pair.
 */
abstract private class MappingBiConsumer extends BiConsumer[String, AnyRef] {

  val relMap = new util.HashMap[String, util.Map[String, AnyRef]]()
  val sourceNodeMap = new util.HashMap[String, util.Map[String, AnyRef]]()
  val targetNodeMap = new util.HashMap[String, util.Map[String, AnyRef]]()

  relMap.put(KEYS, new util.HashMap[String, AnyRef]())
  relMap.put(PROPERTIES, new util.HashMap[String, AnyRef]())
  sourceNodeMap.put(PROPERTIES, new util.HashMap[String, AnyRef]())
  sourceNodeMap.put(KEYS, new util.HashMap[String, AnyRef]())
  targetNodeMap.put(PROPERTIES, new util.HashMap[String, AnyRef]())
  targetNodeMap.put(KEYS, new util.HashMap[String, AnyRef]())
}


================================================
FILE: common/src/main/scala/org/neo4j/spark/service/Neo4jQueryService.scala
================================================
/*
 * Copyright (c) "Neo4j"
 * Neo4j Sweden AB [https://neo4j.com]
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package org.neo4j.spark.service

import org.apache.commons.lang3.StringUtils
import org.apache.spark.internal.Logging
import org.apache.spark.sql.SaveMode
import org.apache.spark.sql.connector.expressions.SortDirection
import org.apache.spark.sql.connector.expressions.SortOrder
import org.apache.spark.sql.connector.expressions.aggregate._
import org.apache.spark.sql.sources.And
import org.apache.spark.sql.sources.Filter
import org.apache.spark.sql.sources.Or
import org.neo4j.caniuse.Neo4j
import org.neo4j.cypherdsl.core._
import org.neo4j.cypherdsl.core.renderer.Renderer
import org.neo4j.spark.cypher.Cypher5Renderer
import org.neo4j.spark.cypher.CypherVersionSelector.selectCypherVersionClause
import org.neo4j.spark.util.Neo4jImplicits._
import org.neo4j.spark.util.Neo4jOptions
import org.neo4j.spark.util.Neo4jUtil
import org.neo4j.spark.util.NodeSaveMode
import org.neo4j.spark.util.QueryType

import scala.collection.JavaConverters._

/**
 * Generates the Cypher statements used for writing: batches arrive as the
 * `$events` parameter, and each event carries the keys/properties buckets
 * produced by [[Neo4jWriteMappingStrategy]].
 */
class Neo4jQueryWriteStrategy(private val neo4j: Neo4j, private val
saveMode: SaveMode) extends Neo4jQueryStrategy {

  // Wraps a user-provided write query: exposes the init-script result and
  // unwinds the $events batch as `event`.
  override def createStatementForQuery(options: Neo4jOptions): String =
    s"""WITH ${"$"}scriptResult AS ${Neo4jQueryStrategy.VARIABLE_SCRIPT_RESULT}
       |UNWIND ${"$"}events AS ${Neo4jQueryStrategy.VARIABLE_EVENT}
       |${options.query.value}
       |""".stripMargin

  // Renders "`key`: event.<prefix>.`key`, ..." pairs for a MERGE/MATCH key map.
  private def createPropsList(props: Map[String, String], prefix: String): String = {
    props
      .map(key => {
        s"${key._2.quote()}: ${Neo4jQueryStrategy.VARIABLE_EVENT}.$prefix.${key._2.quote()}"
      }).mkString(", ")
  }

  // Translates a Spark/Node save mode into the matching Cypher clause keyword.
  // Takes Any because it accepts both SaveMode and NodeSaveMode values.
  private def keywordFromSaveMode(saveMode: Any): String = {
    saveMode match {
      case NodeSaveMode.Overwrite | SaveMode.Overwrite => "MERGE"
      case NodeSaveMode.ErrorIfExists | SaveMode.ErrorIfExists | SaveMode.Append | NodeSaveMode.Append => "CREATE"
      case NodeSaveMode.Match => "MATCH"
      case _ => throw new UnsupportedOperationException(s"SaveMode $saveMode not supported")
    }
  }

  // Renders one node clause, e.g. `MERGE (source:L {k: ...}) SET source += ...`;
  // MATCH clauses get no SET.
  private def createQueryPart(keyword: String, labels: String, keys: String, alias: String): String = {
    val setStatement =
      if (!keyword.equals("MATCH"))
        s" SET $alias += ${Neo4jQueryStrategy.VARIABLE_EVENT}.$alias.${Neo4jWriteMappingStrategy.PROPERTIES}"
      else ""
    s"""$keyword ($alias${if (labels.isEmpty) "" else s":$labels"} ${if (keys.isEmpty) "" else s"{$keys}"})$setStatement""".stripMargin
  }

  /**
   * Builds the full relationship upsert: source clause, target clause, then
   * the relationship clause plus a property SET.
   */
  override def createStatementForRelationships(options: Neo4jOptions): String = {
    val relationshipKeyword = keywordFromSaveMode(saveMode)
    val sourceKeyword = keywordFromSaveMode(options.relationshipMetadata.sourceSaveMode)
    val targetKeyword = keywordFromSaveMode(options.relationshipMetadata.targetSaveMode)

    val relationship = options.relationshipMetadata.relationshipType.quote()

    val sourceLabels = options.relationshipMetadata.source.labels
      .map(_.quote())
      .mkString(":")

    val targetLabels = options.relationshipMetadata.target.labels
      .map(_.quote())
      .mkString(":")

    val sourceKeys = createPropsList(
      options.relationshipMetadata.source.nodeKeys,
      s"source.${Neo4jWriteMappingStrategy.KEYS}"
    )
    val targetKeys = createPropsList(
      options.relationshipMetadata.target.nodeKeys,
      s"target.${Neo4jWriteMappingStrategy.KEYS}"
    )
    val sourceQueryPart = createQueryPart(sourceKeyword, sourceLabels, sourceKeys, Neo4jUtil.RELATIONSHIP_SOURCE_ALIAS)
    val targetQueryPart = createQueryPart(targetKeyword, targetLabels, targetKeys, Neo4jUtil.RELATIONSHIP_TARGET_ALIAS)

    // NOTE(review): only inserted when a non-MATCH source precedes a MATCH
    // target — presumably to separate the write clause from the following
    // MATCH; confirm against the generated-query tests.
    val withQueryPart = if (sourceKeyword != "MATCH" && targetKeyword == "MATCH") "\nWITH source, event"
    else {
      ""
    }

    val relKeys = if (options.relationshipMetadata.relationshipKeys.nonEmpty) {
      options.relationshipMetadata.relationshipKeys
        .map(t =>
          s"${t._2}: ${Neo4jQueryStrategy.VARIABLE_EVENT}.${Neo4jUtil.RELATIONSHIP_ALIAS}.${Neo4jWriteMappingStrategy.KEYS}.${t._1}"
        )
        .mkString("{", ", ", "}")
    } else {
      ""
    }

    s"""${selectCypherVersionClause(neo4j)}UNWIND ${"$"}events AS ${Neo4jQueryStrategy.VARIABLE_EVENT}
       |$sourceQueryPart$withQueryPart
       |$targetQueryPart
       |$relationshipKeyword (${Neo4jUtil.RELATIONSHIP_SOURCE_ALIAS})-[${Neo4jUtil.RELATIONSHIP_ALIAS}:$relationship$relKeys]->(${Neo4jUtil.RELATIONSHIP_TARGET_ALIAS})
       |SET ${Neo4jUtil.RELATIONSHIP_ALIAS} += ${Neo4jQueryStrategy.VARIABLE_EVENT}.${Neo4jUtil.RELATIONSHIP_ALIAS}.${Neo4jWriteMappingStrategy.PROPERTIES}
       |""".stripMargin
  }

  /** Builds the node upsert: CREATE/MERGE/MATCH on keys, then SET properties. */
  override def createStatementForNodes(options: Neo4jOptions): String = {
    val keyword = keywordFromSaveMode(saveMode)
    val labels = options.nodeMetadata.labels
      .map(_.quote())
      .mkString(":")
    val keys = createPropsList(
      options.nodeMetadata.nodeKeys,
      Neo4jWriteMappingStrategy.KEYS
    )
    s"""${selectCypherVersionClause(neo4j)}UNWIND ${"$"}events AS ${Neo4jQueryStrategy.VARIABLE_EVENT}
       |$keyword (node${if (labels.isEmpty) "" else s":$labels"} ${if (keys.isEmpty) "" else s"{$keys}"})
       |SET node += ${Neo4jQueryStrategy.VARIABLE_EVENT}.${Neo4jWriteMappingStrategy.PROPERTIES}
       |""".stripMargin
  }

  override def createStatementForGDS(options: Neo4jOptions): String =
    throw new UnsupportedOperationException("Write operations with GDS are currently not supported")
}

class Neo4jQueryReadStrategy(
  neo4j: Neo4j,
  filters: Array[Filter] = Array.empty[Filter],
  partitionPagination: PartitionPagination = PartitionPagination.EMPTY,
  requiredColumns: Seq[String] = Seq.empty,
  aggregateColumns: Array[AggregateFunc] = Array.empty,
  jobId: String = ""
) extends Neo4jQueryStrategy with Logging {

  private val renderer: Renderer = new Cypher5Renderer(neo4j)

  // Pagination is considered active only when both skip and limit are set.
  private val hasSkipLimit: Boolean = partitionPagination.skip != -1 && partitionPagination.topN.limit != -1

  // Custom query: only SKIP/LIMIT can be appended; pushed-down top-N orders
  // cannot be applied and are dropped with a warning.
  override def createStatementForQuery(options: Neo4jOptions): String = {
    if (partitionPagination.topN.orders.nonEmpty) {
      logWarning(
        s"""Top N push-down optimizations with aggregations are not supported for custom queries.
           |\tThese aggregations are going to be ignored.
           |\tPlease specify the aggregations in the custom query directly""".stripMargin
      )
    }
    val limitedQuery = if (hasSkipLimit) {
      s"${options.query.value} SKIP ${partitionPagination.skip} LIMIT ${partitionPagination.topN.limit}"
    } else {
      s"${options.query.value}"
    }
    s"WITH $$scriptResult AS scriptResult $limitedQuery"
  }

  /** Builds the MATCH (source)-[rel]->(target) read with filters/projection. */
  override def createStatementForRelationships(options: Neo4jOptions): String = {
    val sourceNode = createNode(Neo4jUtil.RELATIONSHIP_SOURCE_ALIAS, options.relationshipMetadata.source.labels)
    val targetNode = createNode(Neo4jUtil.RELATIONSHIP_TARGET_ALIAS, options.relationshipMetadata.target.labels)
    val relationship = sourceNode.relationshipTo(targetNode, options.relationshipMetadata.relationshipType)
      .named(Neo4jUtil.RELATIONSHIP_ALIAS)

    val matchQuery: StatementBuilder.OngoingReadingWithoutWhere =
      filterRelationship(sourceNode, targetNode, relationship)

    val returnExpressions: Seq[Expression] = buildReturnExpression(sourceNode, targetNode, relationship)

    val stmt = if (aggregateColumns.isEmpty) {
      val query = matchQuery.returning(returnExpressions: _*)
      buildStatement(options, query, relationship)
    } else {
      buildStatementAggregation(options, matchQuery, relationship, returnExpressions)
    }
    renderer.render(stmt)
  }

  // Converts a Spark SortOrder to a Cypher-DSL SortItem, resolving which
  // entity (source/target/rel) the sort expression refers to by its alias
  // prefix; unqualified expressions fall back to a symbolic name.
  private def convertSort(entity: PropertyContainer, order: SortOrder): SortItem = {
    val sortExpression = order.expression().describe()

    val container: Option[PropertyContainer] = entity match {
      case relationship: Relationship =>
        if (sortExpression.contains(s"${Neo4jUtil.RELATIONSHIP_SOURCE_ALIAS}.")) {
          Some(relationship.getLeft)
        } else if (sortExpression.contains(s"${Neo4jUtil.RELATIONSHIP_TARGET_ALIAS}.")) {
          Some(relationship.getRight)
        } else if (sortExpression.contains(s"${Neo4jUtil.RELATIONSHIP_ALIAS}.")) {
          Some(relationship)
        } else {
          None
        }
      case _ => Some(entity)
    }

    val direction = if (order.direction() == SortDirection.ASCENDING) SortItem.Direction.ASC
    else SortItem.Direction.DESC

    Cypher.sort(
      container
        .map(_.property(sortExpression.removeAlias()))
        .getOrElse(Cypher.name(sortExpression.unquote())),
      direction
    )
  }

  // RETURN list: whole entities when no projection, otherwise one expression
  // per required column routed to the entity its alias prefix names.
  private def buildReturnExpression(sourceNode: Node, targetNode: Node, relationship: Relationship): Seq[Expression] = {
    if (requiredColumns.isEmpty) {
      Seq(
        relationship.getRequiredSymbolicName,
        sourceNode.as(Neo4jUtil.RELATIONSHIP_SOURCE_ALIAS),
        targetNode.as(Neo4jUtil.RELATIONSHIP_TARGET_ALIAS)
      )
    } else {
      requiredColumns.map(column => {
        val splatColumn = column.split('.')
        val entityName = splatColumn.head
        val entity = if (entityName.contains(Neo4jUtil.RELATIONSHIP_ALIAS)) {
          relationship
        } else if (entityName.contains(Neo4jUtil.RELATIONSHIP_SOURCE_ALIAS)) {
          sourceNode
        } else if (entityName.contains(Neo4jUtil.RELATIONSHIP_TARGET_ALIAS)) {
          targetNode
        } else {
          null
        }
        if (entity != null && splatColumn.length == 1) {
          entity match {
            case n: Node         => n.as(entityName.quote())
            case r: Relationship => r.getRequiredSymbolicName
          }
        } else {
          getCorrectProperty(column, entity)
        }
      })
    }
  }

  // Aggregation variant: pagination must be applied before the aggregate,
  // via WITH ... ORDER BY id(...) SKIP/LIMIT.
  private def buildStatementAggregation(
    options: Neo4jOptions,
    query: StatementBuilder.OngoingReadingWithoutWhere,
    entity: PropertyContainer,
    fields: Seq[Expression]
  ): Statement = {
    val ret = if (hasSkipLimit) {
      val id = entity match {
        case node: Node        => Functions.id(node)
        case rel: Relationship => Functions.id(rel)
      }
      query
        .`with`(entity)
        // Spark does not push down limits/top N when aggregation is involved
        .orderBy(id)
        .skip(partitionPagination.skip)
        .limit(partitionPagination.topN.limit)
        .returning(fields: _*)
    } else {
      val orderByProp = options.streamingOrderBy
      if (StringUtils.isBlank(orderByProp)) {
        query.returning(fields: _*)
      } else {
        query
          .`with`(entity)
          .orderBy(entity.property(orderByProp))
          .ascending()
          .returning(fields: _*)
      }
    }
    ret.build()
  }

  // Non-aggregation variant: adds ORDER BY/SKIP/LIMIT depending on
  // pagination, partition count, pushed-down top-N orders and the streaming
  // order-by option.
  private def buildStatement(
    options: Neo4jOptions,
    returning: StatementBuilder.TerminalExposesSkip
      with StatementBuilder.TerminalExposesLimit
      with StatementBuilder.TerminalExposesOrderBy with StatementBuilder.BuildableStatement[_],
    entity: PropertyContainer = null
  ): Statement = {

    // SKIP 0 is omitted to keep the rendered query minimal.
    def addSkipLimit(ret: StatementBuilder.TerminalExposesSkip
        with StatementBuilder.TerminalExposesLimit
        with StatementBuilder.BuildableStatement[_]) = {
      if (partitionPagination.skip == 0) {
        ret.limit(partitionPagination.topN.limit)
      } else {
        ret.skip(partitionPagination.skip)
          .limit(partitionPagination.topN.limit)
      }
    }

    val ret = if (entity == null) {
      if (hasSkipLimit) addSkipLimit(returning) else returning
    } else {
      if (hasSkipLimit) {
        if (options.partitions == 1 || partitionPagination.topN.orders.nonEmpty) {
          addSkipLimit(returning.orderBy(partitionPagination.topN.orders.map(order => convertSort(entity, order)): _*))
        } else {
          // Multiple partitions need a deterministic order for SKIP/LIMIT to
          // partition the result set consistently: order by internal id.
          val id = entity match {
            case node: Node        => Functions.id(node)
            case rel: Relationship => Functions.id(rel)
          }
          addSkipLimit(returning.orderBy(id))
        }
      } else {
        val orderByProp = options.streamingOrderBy
        if (StringUtils.isBlank(orderByProp)) returning
        else returning.orderBy(entity.property(orderByProp))
      }
    }
    ret.build()
  }

  // MATCH clauses for the relationship pattern plus WHERE from pushed-down
  // filters; each filter is routed to the entity its attribute names.
  private def filterRelationship(sourceNode: Node, targetNode: Node, relationship: Relationship) = {
    val matchQuery = Cypher.`match`(sourceNode).`match`(targetNode).`match`(relationship)

    def getContainer(filter: Filter): PropertyContainer = {
      if (filter.isAttribute(Neo4jUtil.RELATIONSHIP_SOURCE_ALIAS)) {
        sourceNode
      } else if (filter.isAttribute(Neo4jUtil.RELATIONSHIP_TARGET_ALIAS)) {
        targetNode
      } else if (filter.isAttribute(Neo4jUtil.RELATIONSHIP_ALIAS)) {
        relationship
      } else {
        throw new IllegalArgumentException(s"Attribute '${filter.getAttribute.get}' is not valid")
      }
    }

    if (filters.nonEmpty) {
      def mapFilter(filter: Filter): Condition = {
        filter match {
          case and: And => mapFilter(and.left).and(mapFilter(and.right))
          case or: Or   => mapFilter(or.left).or(mapFilter(or.right))
          case filter: Filter =>
            Neo4jUtil.mapSparkFiltersToCypher(filter, getContainer(filter), filter.getAttributeWithoutEntityName)
        }
      }
      val cypherFilters = filters.map(mapFilter)
      assembleConditionQuery(matchQuery, cypherFilters)
    }
    matchQuery
  }

  // Resolves one projected column to a Cypher expression: the synthetic
  // <id>/<labels>/<rel...> fields map to id()/labels()/type(); pushed-down
  // aggregates map to count/min/max/sum; anything else becomes a property
  // access (or a bare name when there is no entity).
  private def getCorrectProperty(column: String, entity: PropertyContainer): Expression = {
    def propertyOrSymbolicName(col: String) = {
      if (entity != null) entity.property(col) else Cypher.name(col)
    }

    column match {
      case Neo4jUtil.INTERNAL_ID_FIELD => Functions.id(entity.asInstanceOf[Node]).as(Neo4jUtil.INTERNAL_ID_FIELD)
      case Neo4jUtil.INTERNAL_REL_ID_FIELD =>
        Functions.id(entity.asInstanceOf[Relationship]).as(Neo4jUtil.INTERNAL_REL_ID_FIELD)
      case Neo4jUtil.INTERNAL_REL_SOURCE_ID_FIELD =>
        Functions.id(entity.asInstanceOf[Node]).as(Neo4jUtil.INTERNAL_REL_SOURCE_ID_FIELD)
      case Neo4jUtil.INTERNAL_REL_TARGET_ID_FIELD =>
        Functions.id(entity.asInstanceOf[Node]).as(Neo4jUtil.INTERNAL_REL_TARGET_ID_FIELD)
      case Neo4jUtil.INTERNAL_REL_TYPE_FIELD =>
        Functions.`type`(entity.asInstanceOf[Relationship]).as(Neo4jUtil.INTERNAL_REL_TYPE_FIELD)
      case Neo4jUtil.INTERNAL_LABELS_FIELD =>
        Functions.labels(entity.asInstanceOf[Node]).as(Neo4jUtil.INTERNAL_LABELS_FIELD)
      case Neo4jUtil.INTERNAL_REL_SOURCE_LABELS_FIELD =>
        Functions.labels(entity.asInstanceOf[Node]).as(Neo4jUtil.INTERNAL_REL_SOURCE_LABELS_FIELD)
      case Neo4jUtil.INTERNAL_REL_TARGET_LABELS_FIELD =>
        Functions.labels(entity.asInstanceOf[Node]).as(Neo4jUtil.INTERNAL_REL_TARGET_LABELS_FIELD)
      case "*" => Asterisk.INSTANCE
      case name => {
        val cleanedName = name.removeAlias()
        aggregateColumns.find(_.toString == name)
          .map {
            case count: Count => {
              val col = count.column().describe().unquote().removeAlias()
              val prop = propertyOrSymbolicName(col)
              if (count.isDistinct) {
                Functions.countDistinct(prop).as(name)
              } else {
                Functions.count(prop).as(name)
              }
            }
            case countStar: CountStar => Functions.count(Asterisk.INSTANCE).as(name)
            case max: Max =>
              val col = max.column().describe().unquote().removeAlias()
              val prop = propertyOrSymbolicName(col)
              Functions.max(prop).as(name)
            case min: Min =>
              val col = min.column().describe().unquote().removeAlias()
              val prop = propertyOrSymbolicName(col)
              Functions.min(prop).as(name)
            case sum: Sum => {
              val col = sum.column().describe().unquote().removeAlias()
              val prop = propertyOrSymbolicName(col)
              if (sum.isDistinct) {
                Functions.sumDistinct(prop).as(name)
              } else {
                Functions.sum(prop).as(name)
              }
            }
          }
          .getOrElse(propertyOrSymbolicName(cleanedName).as(name))
          .asInstanceOf[Expression]
      }
    }
  }

  /** Builds the MATCH (node) read with filters, projection and aggregation. */
  override def createStatementForNodes(options: Neo4jOptions): String = {
    val node = createNode(Neo4jUtil.NODE_ALIAS, options.nodeMetadata.labels)
    val matchQuery = filterNode(node)
    val expressions = requiredColumns.map(column => getCorrectProperty(column, node))
    val stmt = if (aggregateColumns.nonEmpty) {
      buildStatementAggregation(options, matchQuery, node, expressions)
    } else {
      val ret = if (requiredColumns.isEmpty) {
        matchQuery.returning(node)
      } else {
        matchQuery.returning(expressions: _*)
      }
      buildStatement(options, ret, node)
    }
    renderer.render(stmt)
  }

  // MATCH clause for the node pattern plus WHERE from pushed-down filters.
  private def filterNode(node: Node) = {
    val matchQuery = Cypher.`match`(node)

    if (filters.nonEmpty) {
      def mapFilter(filter: Filter): Condition = {
        filter match {
          case and: And       => mapFilter(and.left).and(mapFilter(and.right))
          case or: Or         => mapFilter(or.left).or(mapFilter(or.right))
          case filter: Filter => Neo4jUtil.mapSparkFiltersToCypher(filter, node)
        }
      }
      val cypherFilters = filters.map(mapFilter)
      assembleConditionQuery(matchQuery, cypherFilters)
    }
    matchQuery
  }

  /** Count query for the node pattern (used for partition sizing). */
  def createStatementForNodeCount(options: Neo4jOptions): String = {
    val node = createNode(Neo4jUtil.NODE_ALIAS, options.nodeMetadata.labels)
    val matchQuery = filterNode(node)
    renderer.render(buildStatement(options, matchQuery.returning(Functions.count(node).as("count"))))
  }

  /** Count query for the relationship pattern (used for partition sizing). */
  def createStatementForRelationshipCount(options: Neo4jOptions): String = {
    val sourceNode = createNode(Neo4jUtil.RELATIONSHIP_SOURCE_ALIAS, options.relationshipMetadata.source.labels)
    val targetNode = createNode(Neo4jUtil.RELATIONSHIP_TARGET_ALIAS, options.relationshipMetadata.target.labels)
    val relationship = sourceNode.relationshipTo(targetNode, options.relationshipMetadata.relationshipType)
      .named(Neo4jUtil.RELATIONSHIP_ALIAS)

    val matchQuery: StatementBuilder.OngoingReadingWithoutWhere =
      filterRelationship(sourceNode, targetNode, relationship)

    renderer.render(buildStatement(options, matchQuery.returning(Functions.count(sourceNode).as("count"))))
  }

  // ANDs all mapped filter conditions into a single WHERE.
  private def assembleConditionQuery(
    matchQuery: StatementBuilder.OngoingReadingWithoutWhere,
    filters: Array[Condition]
  ): StatementBuilder.OngoingReadingWithWhere = {
    matchQuery.where(
      filters.fold(Conditions.noCondition()) { (a, b) => a.and(b) }
    )
  }

  // NOTE(review): labels.head is evaluated before the labels.isEmpty guard;
  // an empty label list would throw here rather than reach Cypher.anyNode —
  // confirm callers never pass an empty list.
  private def createNode(name: String, labels: Seq[String]) = {
    val primaryLabel = labels.head
    val otherLabels = labels.tail
    if (labels.isEmpty) {
      Cypher.anyNode(name)
    } else {
      Cypher.node(primaryLabel, otherLabels.asJava).named(name)
    }
  }

  /** CALL ... YIELD ... RETURN for a GDS procedure, passing only parameters
   *  that are required or explicitly configured. */
  override def createStatementForGDS(options: Neo4jOptions): String = {
    val retCols = requiredColumns.map(column => getCorrectProperty(column, null))
    // we need it in order to parse the field YIELD by the GDS procedure...
    val (yieldFields, args) = Neo4jUtil.callSchemaService(
      neo4j,
      options,
      jobId,
      filters,
      { ss => (ss.struct().fieldNames, ss.inputForGDSProc(options.query.value)) }
    )
    val cypherParams = args
      .filter(t => {
        if (!t._2) {
          true
        } else {
          options.gdsMetadata.parameters.containsKey(t._1)
        }
      })
      .map(_._1)
      .map(Cypher.parameter)
    val statement = Cypher.call(options.query.value)
      .withArgs(cypherParams: _*)
      .`yield`(yieldFields: _*)
      .returning(retCols: _*)
      .build()
    renderer.render(statement)
  }
}

object Neo4jQueryStrategy {
  // Well-known parameter/variable names shared by the generated queries.
  val VARIABLE_EVENT = "event"
  val VARIABLE_EVENTS = "events"
  val VARIABLE_SCRIPT_RESULT = "scriptResult"
  val VARIABLE_STREAM = "stream"
}

/** Contract for generating a Cypher statement per query type. */
abstract class Neo4jQueryStrategy {

  def createStatementForQuery(options: Neo4jOptions): String

  def createStatementForRelationships(options: Neo4jOptions): String

  def createStatementForNodes(options: Neo4jOptions): String

  def createStatementForGDS(options: Neo4jOptions): String
}

/** Dispatches statement creation to the strategy method for the query type. */
class Neo4jQueryService(private val options: Neo4jOptions, val strategy: Neo4jQueryStrategy) extends Serializable {

  def createQuery(): String = options.query.queryType match {
    case QueryType.LABELS       => strategy.createStatementForNodes(options)
    case QueryType.RELATIONSHIP => strategy.createStatementForRelationships(options)
    case QueryType.QUERY        => strategy.createStatementForQuery(options)
    case QueryType.GDS          => strategy.createStatementForGDS(options)
    case _ => throw new UnsupportedOperationException(
        s"""Query Type not supported.
           |You provided ${options.query.queryType},
           |supported types: ${QueryType.values.mkString(",")}""".stripMargin
      )
  }
}


================================================
FILE: common/src/main/scala/org/neo4j/spark/service/SchemaService.scala
================================================
/*
 * Copyright (c) "Neo4j"
 * Neo4j Sweden AB [https://neo4j.com]
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
* You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ package org.neo4j.spark.service import org.apache.commons.lang3.exception.ExceptionUtils import org.apache.spark.internal.Logging import org.apache.spark.sql.sources.Filter import org.apache.spark.sql.types.DataTypes import org.apache.spark.sql.types.StructField import org.apache.spark.sql.types.StructType import org.neo4j.caniuse.Neo4j import org.neo4j.driver.Record import org.neo4j.driver.Session import org.neo4j.driver.Transaction import org.neo4j.driver.TransactionWork import org.neo4j.driver.Value import org.neo4j.driver.Values import org.neo4j.driver.exceptions.ClientException import org.neo4j.driver.summary import org.neo4j.spark.config.TopN import org.neo4j.spark.converter.CypherToSparkTypeConverter import org.neo4j.spark.converter.SparkToCypherTypeConverter import org.neo4j.spark.cypher.CypherVersionSelector.selectCypherVersionClause import org.neo4j.spark.service.SchemaService.normalizedClassName import org.neo4j.spark.service.SchemaService.normalizedClassNameFromGraphEntity import org.neo4j.spark.util.Neo4jImplicits.CypherImplicits import org.neo4j.spark.util.Neo4jImplicits.ValueImplicits import org.neo4j.spark.util.OptimizationType import org.neo4j.spark.util._ import java.util import java.util.Collections import java.util.function import scala.collection.JavaConverters._ import scala.collection.mutable import scala.collection.mutable.ArrayBuffer object PartitionPagination { val EMPTY: PartitionPagination = PartitionPagination(0, -1, TopN(-1)) } case class PartitionPagination(partitionNumber: Int, skip: Long, topN: TopN) class SchemaService( private 
val neo4j: Neo4j, private val options: Neo4jOptions, private val driverCache: DriverCache, private val filters: Array[Filter] = Array.empty ) extends AutoCloseable with Logging { private val queryReadStrategy = new Neo4jQueryReadStrategy(neo4j, filters) private val session: Session = driverCache.getOrCreate().session(options.session.toNeo4jSession()) private val sessionTransactionConfig = options.toNeo4jTransactionConfig private val cypherToSparkTypeConverter = CypherToSparkTypeConverter(options) private val sparkToCypherTypeConverter = SparkToCypherTypeConverter(options) private def structForNode(labels: Seq[String] = options.nodeMetadata.labels) = { val structFields: mutable.Buffer[StructField] = (try { val query = s"""${selectCypherVersionClause(neo4j)}CALL apoc.meta.nodeTypeProperties($$config) |YIELD propertyName, propertyTypes |WITH DISTINCT propertyName, propertyTypes |WITH propertyName, collect(propertyTypes) AS propertyTypes |RETURN propertyName, reduce(acc = [], elem IN propertyTypes | acc + elem) AS propertyTypes |""".stripMargin val apocConfig = options.apocConfig.procedureConfigMap .getOrElse("apoc.meta.nodeTypeProperties", Map.empty[String, AnyRef]) .asInstanceOf[Map[String, AnyRef]] ++ Map[String, AnyRef]("includeLabels" -> labels.asJava) retrieveSchemaFromApoc(query, Collections.singletonMap("config", apocConfig.asJava)) } catch { case e: ClientException => logResolutionChange("Switching to query schema resolution", e) // TODO get back to Cypher DSL when rand function will be available val query = s"""${selectCypherVersionClause(neo4j)}MATCH (${Neo4jUtil.NODE_ALIAS}:${labels.map(_.quote()).mkString(":")}) |RETURN ${Neo4jUtil.NODE_ALIAS} |ORDER BY rand() |LIMIT ${options.schemaMetadata.flattenLimit} |""".stripMargin val params = Collections.emptyMap[String, AnyRef]() retrieveSchema(query, params, { record => record.get(Neo4jUtil.NODE_ALIAS).asNode.asMap.asScala.toMap }) }) .sortBy(t => t.name) structFields += StructField( 
Neo4jUtil.INTERNAL_LABELS_FIELD, DataTypes.createArrayType(DataTypes.StringType), nullable = true ) structFields += StructField(Neo4jUtil.INTERNAL_ID_FIELD, DataTypes.LongType, nullable = false) StructType(structFields.reverse.toSeq) } private def retrieveSchemaFromApoc( query: String, params: java.util.Map[String, AnyRef] ): mutable.Buffer[StructField] = { val fields = session.run(query, params, sessionTransactionConfig) .list .asScala .filter(record => !record.get("propertyName").isNull && !record.get("propertyName").isEmpty) .map(record => { val fieldTypesList = record.get("propertyTypes") .asList(new function.Function[Value, String]() { override def apply(v: Value): String = v.asString() }) .asScala val fieldType: String = if (fieldTypesList.size > 1) { log.warn( s""" |The field ${record.get("propertyName")} has different types: $fieldTypesList |Every value will be casted to string. |""".stripMargin ) "String" } else { fieldTypesList.head } StructField(record.get("propertyName").asString, cypherToSparkTypeConverter.convert(fieldType)) }) if (fields.isEmpty) { throw new ClientException("Unable to compute the resulting schema from APOC") } fields } private def retrieveSchema( query: String, params: java.util.Map[String, AnyRef], extractFunction: Record => Map[String, AnyRef] ): mutable.Buffer[StructField] = { session.run(query, params, sessionTransactionConfig).list.asScala .flatMap(extractFunction) .groupBy(_._1) .mapValues(_.map(_._2)) .map(t => options.schemaMetadata.strategy match { case SchemaStrategy.SAMPLE => { val types = t._2.map(value => { if (options.query.queryType == QueryType.QUERY) { normalizedClassName(value, options) } else { normalizedClassNameFromGraphEntity(value, options) } }).toSet if (types.size > 1) { log.warn( s""" |The field ${t._1} has different types: ${types.toString} |Every value will be casted to string. 
|""".stripMargin ) StructField(t._1, DataTypes.StringType) } else { val value = t._2.head StructField(t._1, cypherToSparkTypeConverter.convert(types.head, value)) } } case SchemaStrategy.STRING => StructField(t._1, DataTypes.StringType) } ) .toBuffer } private def mapStructField(alias: String, field: StructField): StructField = { val name = field.name match { case Neo4jUtil.INTERNAL_ID_FIELD | Neo4jUtil.INTERNAL_LABELS_FIELD => s"<$alias.${field.name.replaceAll("[<|>]", "")}>" case _ => s"$alias.${field.name}" } StructField(name, field.dataType, field.nullable, field.metadata) } private def structForRelationship() = { val structFields: mutable.Buffer[StructField] = ArrayBuffer( StructField(Neo4jUtil.INTERNAL_REL_ID_FIELD, DataTypes.LongType, false), StructField(Neo4jUtil.INTERNAL_REL_TYPE_FIELD, DataTypes.StringType, false) ) if (options.relationshipMetadata.nodeMap) { structFields += StructField( s"<${Neo4jUtil.RELATIONSHIP_SOURCE_ALIAS}>", DataTypes.createMapType(DataTypes.StringType, DataTypes.StringType), false ) structFields += StructField( s"<${Neo4jUtil.RELATIONSHIP_TARGET_ALIAS}>", DataTypes.createMapType(DataTypes.StringType, DataTypes.StringType), false ) } else { structFields ++= structForNode(options.relationshipMetadata.source.labels) .map(field => mapStructField(Neo4jUtil.RELATIONSHIP_SOURCE_ALIAS, field)) structFields ++= structForNode(options.relationshipMetadata.target.labels) .map(field => mapStructField(Neo4jUtil.RELATIONSHIP_TARGET_ALIAS, field)) } structFields ++= (try { val query = s"""${selectCypherVersionClause( neo4j )}CALL apoc.meta.relTypeProperties($$config) YIELD sourceNodeLabels, targetNodeLabels, | propertyName, propertyTypes |WITH * |WHERE sourceNodeLabels = $$sourceLabels AND targetNodeLabels = $$targetLabels |RETURN * |""".stripMargin val apocConfig = options.apocConfig.procedureConfigMap .getOrElse("apoc.meta.relTypeProperties", Map.empty[String, AnyRef]) .asInstanceOf[Map[String, AnyRef]] val config = apocConfig ++ 
Map("includeRels" -> Seq(options.relationshipMetadata.relationshipType).asJava) val params = Map[String, AnyRef]( "config" -> config.asJava, "sourceLabels" -> options.relationshipMetadata.source.labels.asJava, "targetLabels" -> options.relationshipMetadata.target.labels.asJava ) .asJava retrieveSchemaFromApoc(query, params) } catch { case e: ClientException => logResolutionChange("Switching to query schema resolution", e) // TODO get back to Cypher DSL when rand function will be available val query = s"""${selectCypherVersionClause( neo4j )}MATCH (${Neo4jUtil.RELATIONSHIP_SOURCE_ALIAS}:${options.relationshipMetadata.source.labels.map( _.quote() ).mkString(":")}) |MATCH (${Neo4jUtil.RELATIONSHIP_TARGET_ALIAS}:${options.relationshipMetadata.target.labels.map( _.quote() ).mkString(":")}) |MATCH (${Neo4jUtil.RELATIONSHIP_SOURCE_ALIAS})-[${Neo4jUtil.RELATIONSHIP_ALIAS}:${options.relationshipMetadata.relationshipType}]->(${Neo4jUtil.RELATIONSHIP_TARGET_ALIAS}) |RETURN ${Neo4jUtil.RELATIONSHIP_ALIAS} |ORDER BY rand() |LIMIT ${options.schemaMetadata.flattenLimit} |""".stripMargin val params = Collections.emptyMap[String, AnyRef]() retrieveSchema( query, params, { record => record.get(Neo4jUtil.RELATIONSHIP_ALIAS).asRelationship.asMap.asScala.toMap } ) }) .map(field => StructField(s"rel.${field.name}", field.dataType, field.nullable, field.metadata)) .sortBy(t => t.name) StructType(structFields.toSeq) } private def structForQuery(): StructType = { val query = queryReadStrategy.createStatementForQuery(options) if (!isValidQuery(query, summary.QueryType.READ_ONLY)) { return new StructType() } val params = Map[String, AnyRef]( Neo4jQueryStrategy.VARIABLE_SCRIPT_RESULT -> Collections.emptyList(), Neo4jQueryStrategy.VARIABLE_STREAM -> Collections.emptyMap() ) .asJava val randLimitedQueryForSchema = s""" |$query |ORDER BY rand() |LIMIT ${options.schemaMetadata.flattenLimit} |""".stripMargin val randCallLimitedQueryForSchema = s""" |CALL { | $query |} RETURN * |ORDER BY rand() 
|LIMIT ${options.schemaMetadata.flattenLimit} |""".stripMargin val limitedQuery = if (isValidQuery(randLimitedQueryForSchema)) randLimitedQueryForSchema else randCallLimitedQueryForSchema val structFields = retrieveSchema(limitedQuery, params, { record => record.asMap.asScala.toMap }) val columns = getReturnedColumns(query) if (columns.isEmpty && structFields.isEmpty) { throw new ClientException( "Unable to compute the resulting schema; this may mean your result set is empty or your version of Neo4j does not permit schema inference for empty sets" ) } if (columns.isEmpty) { return StructType(structFields.toSeq) } val sortedStructFields = if (structFields.isEmpty) { // df: we arrived here because there are no data returned by the query // so we want to return an empty dataset which schema is equals to the columns // specified by the RETURN statement columns.map(StructField(_, DataTypes.StringType)) } else { try { columns.map(column => structFields.find(_.name.quote() == column.quote()).orNull).filter(_ != null) } catch { case _: Throwable => structFields.toArray } } StructType(sortedStructFields) } private def structForGDS() = { val query = s""" |${selectCypherVersionClause(neo4j)}CALL gds.list() YIELD name, signature, type |WHERE name = $$procName AND type = 'procedure' |WITH split(signature, ') :: (')[1] AS fields |WITH substring(fields, 0, size(fields) - 1) AS fields |WITH split(fields, ',') AS fields |WITH [field IN fields | split(field, ' :: ')] AS fields |UNWIND fields AS field |WITH field |RETURN * |""".stripMargin val map: util.Map[String, AnyRef] = Map[String, AnyRef]("procName" -> options.query.value).asJava val fields = session.run(query, map, sessionTransactionConfig).list.asScala .map(r => r.get("field").asList((t: Value) => t.asString()).asScala) .map(r => ( r.head.trim, r(1).replaceAll("\\?", "") match { case "STRING" => ("String", null) case "INTEGER" => ("Long", null) case "FLOAT" | "NUMBER" => ("Double", null) case "DATETIME" => ("DateTime", null) 
case "BOOLEAN" => ("Boolean", null)
case "LOCALTIME" => ("LocalTime", null)
// Neo4j 4 renders list types as `LIST? OF <TYPE>?` (the "?" are stripped by the
// replaceAll above); Neo4j 5 renders them as `LIST<TYPE [NOT NULL]>`.
// NOTE(review): in this copy the bracketed alternatives appeared garbled to bare
// "LIST", making the later cases unreachable duplicates — reconstructed here;
// confirm against actual gds.list() signature output.
case "LIST OF INTEGER" | "LIST<INTEGER NOT NULL>" | "LIST<INTEGER>" => ("LongArray", null)
case "LIST OF FLOAT" | "LIST<FLOAT NOT NULL>" | "LIST<FLOAT>" => ("DoubleArray", null)
case "LIST OF STRING" | "LIST<STRING NOT NULL>" | "LIST<STRING>" => ("StringArray", null)
case "MAP" =>
  logWarning(
    s"""
       |For procedure ${options.query.value}
       |Neo4j return type MAP? of field ${r.head.trim} not fully supported.
       |We'll coerce it to a Map
       |""".stripMargin
  )
  ("Map", Map("key" -> "").asJava) // dummy value
case "LIST OF MAP" | "LIST<MAP NOT NULL>" | "LIST<MAP>" =>
  logWarning(
    s"""
       |For procedure ${options.query.value}
       |Neo4j return type LIST? OF MAP? of field ${r.head.trim} not fully supported.
       |We'll coerce it to a [Map]
       |""".stripMargin
  )
  ("MapArray", Seq(Map("key" -> "").asJava).asJava) // dummy value
case "PATH" => ("Path", null)
case _ => throw new IllegalArgumentException(s"Neo4j type ${r(1)} not supported")
          }
        )
      )
      .map(r => StructField(r._1, cypherToSparkTypeConverter.convert(r._2._1, r._2._2)))
      .toSeq
    StructType(fields)
  }

  /**
   * Returns the declared input parameters of a GDS procedure as
   * (parameterName, isOptional) pairs, parsed from its `gds.list()` signature.
   * A parameter is optional when its signature fragment carries a default (` = `).
   */
  def inputForGDSProc(procName: String): Seq[(String, Boolean)] = {
    val query =
      """
        |WITH $procName AS procName
        |CALL gds.list() YIELD name, signature, type
        |WHERE name = procName AND type = 'procedure'
        |WITH replace(signature, procName + '(', '') AS signature
        |WITH split(signature, ') :: (')[0] AS fields
        |WITH substring(fields, 0, size(fields) - 1) AS fields
        |WITH split(fields, ',') AS fields
        |WITH [field IN fields | split(field, ' :: ')] AS fields
        |UNWIND fields AS field
        |WITH trim(split(field[0], ' = ')[0]) AS fieldName, field[0] contains ' = ' AS optional
        |RETURN *
        |""".stripMargin
    val map: util.Map[String, AnyRef] = Map[String, AnyRef]("procName" -> procName).asJava
    session.run(query, map, sessionTransactionConfig)
      .list
      .asScala
      .map(r => (r.get("fieldName").asString(), r.get("optional").asBoolean()))
      .toSeq
  }

  // Column names returned by the query, obtained via EXPLAIN (no data is read).
  private def getReturnedColumns(query: String): Array[String] =
    session.run("EXPLAIN " + query, sessionTransactionConfig)
      .keys().asScala.toArray

  def struct():
StructType = { val struct = options.query.queryType match { case QueryType.LABELS => structForNode() case QueryType.RELATIONSHIP => structForRelationship() case QueryType.QUERY => structForQuery() case QueryType.GDS => structForGDS() } struct } def countForNodeWithQuery(filters: Array[Filter]): Long = { val query = if (filters.isEmpty) { options.nodeMetadata.labels .map(_.quote()) .map(label => s""" |MATCH (:$label) |RETURN count(*) AS count""".stripMargin ) .mkString(" UNION ALL ") } else { queryReadStrategy.createStatementForNodeCount(options) } log.info(s"Executing the following counting query on Neo4j: $query") session.readTransaction( tx => tx.run(query, Values.value(Neo4jUtil.paramsFromFilters(filters).asJava)) .list() .asScala .map(_.get("count")) .map(count => if (count.isNull) 0L else count.asLong()) .min, sessionTransactionConfig ) } def countForRelationshipWithQuery(filters: Array[Filter]): Long = { val query = if (filters.isEmpty) { val sourceQueries = options.relationshipMetadata.source.labels .map(_.quote()) .map(label => s"""MATCH (:$label)-[${Neo4jUtil.RELATIONSHIP_ALIAS}:${options.relationshipMetadata.relationshipType.quote()}]->() |RETURN count(${Neo4jUtil.RELATIONSHIP_ALIAS}) AS count |""".stripMargin ) val targetQueries = options.relationshipMetadata.target.labels .map(_.quote()) .map(label => s"""MATCH ()-[${Neo4jUtil.RELATIONSHIP_ALIAS}:${options.relationshipMetadata.relationshipType.quote()}]->(:$label) |RETURN count(${Neo4jUtil.RELATIONSHIP_ALIAS}) AS count |""".stripMargin ) (sourceQueries ++ targetQueries) .mkString(" UNION ALL ") } else { queryReadStrategy.createStatementForRelationshipCount(options) } log.info(s"Executing the following counting query on Neo4j: $query") session.run(query, sessionTransactionConfig) .list() .asScala .map(_.get("count")) .map(count => if (count.isNull) 0L else count.asLong()) .min } def countForNode(filters: Array[Filter]): Long = try { /* * we try to leverage the count store in order to have the faster 
response possible * https://neo4j.com/developer/kb/fast-counts-using-the-count-store/ * so in this scenario we have some limitations given the fact that we get the min * for the sequence of counts returned */ if (filters.isEmpty) { val query = "CALL apoc.meta.stats() yield labels RETURN labels" val map = session.run(query, sessionTransactionConfig).single() .asMap() .asScala .get("labels") .getOrElse(Collections.emptyMap()) .asInstanceOf[util.Map[String, Long]].asScala map.filterKeys(k => options.nodeMetadata.labels.contains(k)) .values.min } else { countForNodeWithQuery(filters) } } catch { case e: ClientException => { logResolutionChange("Switching to query count resolution", e) countForNodeWithQuery(filters) } case e: Throwable => logExceptionForCount(e) } def countForRelationship(filters: Array[Filter]): Long = try { if (filters.isEmpty) { val query = "CALL apoc.meta.stats() yield relTypes RETURN relTypes" val map = session.run(query, sessionTransactionConfig).single() .asMap() .asScala .get("relTypes") .getOrElse(Collections.emptyMap()) .asInstanceOf[util.Map[String, Long]] .asScala val minFromSource = options.relationshipMetadata.source.labels .map(_.quote()) .map(label => map.get(s"(:$label)-[:${options.relationshipMetadata.relationshipType}]->()").getOrElse(Long.MaxValue) ) .min val minFromTarget = options.relationshipMetadata.target.labels .map(_.quote()) .map(label => map.get(s"()-[:${options.relationshipMetadata.relationshipType}]->(:$label)").getOrElse(Long.MaxValue) ) .min Math.min(minFromSource, minFromTarget) } else { countForRelationshipWithQuery(filters) } } catch { case e: ClientException => { logResolutionChange("Switching to query count resolution", e) countForRelationshipWithQuery(filters) } case e: Throwable => logExceptionForCount(e) } private def logExceptionForCount(e: Throwable): Long = { log.error("Cannot compute the count because the following exception:", e) -1 } def skipLimitFromPartition(topN: Option[TopN]): Seq[PartitionPagination] = 
if (options.partitions == 1) { val skipLimit = topN.map(top => PartitionPagination(0, 0, top)).getOrElse(PartitionPagination.EMPTY) Seq(skipLimit) } else { val count: Long = this.count() if (count <= 0) { Seq(PartitionPagination.EMPTY) } else { val partitionSize = Math.ceil(count.toDouble / options.partitions).toLong (0 until options.partitions) .map(index => PartitionPagination(index, index * partitionSize, TopN(partitionSize))) } } def count(filters: Array[Filter] = this.filters): Long = options.query.queryType match { case QueryType.LABELS => countForNode(filters) case QueryType.RELATIONSHIP => countForRelationship(filters) case QueryType.QUERY => countForQuery() } private def countForQuery(): Long = { val queryCount: String = options.queryMetadata.queryCount if (Neo4jUtil.isLong(queryCount)) { queryCount.trim.toLong } else { val query = if (queryCount.nonEmpty) { options.queryMetadata.queryCount } else { s"""CALL { ${options.query.value} } |RETURN count(*) AS count |""".stripMargin } session.run(query, sessionTransactionConfig).single().get("count").asLong() } } def isGdsProcedure(procName: String): Boolean = { val params: util.Map[String, AnyRef] = Map[String, AnyRef]("procName" -> procName).asJava session.run( """ |CALL gds.list() YIELD name, type |WHERE name = $procName AND type = 'procedure' |RETURN count(*) = 1 |""".stripMargin, params, sessionTransactionConfig ) .single() .get(0) .asBoolean() } def validateQuery(query: String, expectedQueryTypes: org.neo4j.driver.summary.QueryType*): String = try { val queryType = session.run(s"EXPLAIN $query", sessionTransactionConfig).consume().queryType() if (expectedQueryTypes.isEmpty || expectedQueryTypes.contains(queryType)) { "" } else { s"Invalid query `${cleanQuery(query)}` because the accepted types are [${expectedQueryTypes.mkString(", ")}], but the actual type is $queryType" } } catch { case e: Throwable => s"Query not compiled for the following exception: ${ExceptionUtils.getMessage(e)}" } private def 
cleanQuery(query: String) = { query .replace( s"WITH {} AS ${Neo4jQueryStrategy.VARIABLE_EVENT}, [] as ${Neo4jQueryStrategy.VARIABLE_SCRIPT_RESULT}", "" ) .replace(s"WITH [] as ${Neo4jQueryStrategy.VARIABLE_SCRIPT_RESULT}", "") .replace(s"WITH {} AS ${Neo4jQueryStrategy.VARIABLE_EVENT}", "") .trim } def validateQueryCount(query: String): String = try { val resultSummary = session.run(s"EXPLAIN $query", sessionTransactionConfig).consume() val queryType = resultSummary.queryType() val plan = resultSummary.plan() val expectedQueryTypes = Set(org.neo4j.driver.summary.QueryType.READ_ONLY, org.neo4j.driver.summary.QueryType.SCHEMA_WRITE) val isReadOnly = expectedQueryTypes.contains(queryType) val hasCountIdentifier = plan.identifiers().asScala.toSet == Set("count") if (isReadOnly && hasCountIdentifier) { "" } else { s"Invalid query `${cleanQuery(query)}` because the expected type should be [${expectedQueryTypes.mkString(", ")}], but the actual type is $queryType" } } catch { case e: Throwable => s"Query count not compiled for the following exception: ${ExceptionUtils.getMessage(e)}" } def isValidQuery(query: String, expectedQueryTypes: org.neo4j.driver.summary.QueryType*): Boolean = try { val queryType = session.run(s"EXPLAIN $query", sessionTransactionConfig).consume().queryType() expectedQueryTypes.isEmpty || expectedQueryTypes.contains(queryType) } catch { case e: Throwable => { if (log.isDebugEnabled) { log.debug("Query not compiled because of the following exception:", e) } false } } @deprecated("use createEntityConstraint instead") private def createIndexOrConstraint(action: OptimizationType.Value, label: String, props: Seq[String]): Unit = action match { case OptimizationType.NONE => log.info("No optimization type provided") case _ => { try { val quotedLabel = label.quote() val quotedProps = props .map(prop => s"${Neo4jUtil.NODE_ALIAS}.${prop.quote()}") .mkString(", ") val isNeo4j4 = neo4j.getVersion.getMajor == 4 val uniqueFieldName = if (!isNeo4j4) 
"owningConstraint" else "uniqueness" val dashSeparatedProps = props.mkString("-") val (querySuffix, uniqueCondition) = action match { case OptimizationType.INDEX => ( s"FOR (${Neo4jUtil.NODE_ALIAS}:$quotedLabel) ON ($quotedProps)", if (!isNeo4j4) s"$uniqueFieldName IS NULL" else s"$uniqueFieldName = 'NONUNIQUE'" ) case OptimizationType.NODE_CONSTRAINTS => { val assertType = if (props.size > 1) "NODE KEY" else "UNIQUE" ( s"FOR (${Neo4jUtil.NODE_ALIAS}:$quotedLabel) REQUIRE ($quotedProps) IS $assertType", if (!isNeo4j4) s"$uniqueFieldName IS NOT NULL" else s"$uniqueFieldName = 'UNIQUE'" ) } } val actionName = s"spark_${action.toString}_${label}_$dashSeparatedProps".quote() val queryPrefix = action match { case OptimizationType.INDEX => s"CREATE INDEX $actionName" case OptimizationType.NODE_CONSTRAINTS => s"CREATE CONSTRAINT $actionName" } val queryCheck = s"""SHOW INDEXES YIELD labelsOrTypes, properties, $uniqueFieldName |WHERE labelsOrTypes = ${'$'}labels |AND properties = ${'$'}properties |AND $uniqueCondition |RETURN count(*) > 0 AS isPresent""".stripMargin val params: util.Map[String, AnyRef] = Map( "labels" -> Seq(label).asJava, "properties" -> props.asJava ).asJava.asInstanceOf[util.Map[String, AnyRef]] val isPresent = session.run(queryCheck, params, sessionTransactionConfig) .single() .get("isPresent") .asBoolean() val status = if (isPresent) { "KEPT" } else { val query = s"$queryPrefix $querySuffix" log.info(s"Performing the following schema query: $query") session.run(query, sessionTransactionConfig) "CREATED" } log.info(s"Status for $action named with label $quotedLabel and props $quotedProps is: $status") } catch { case e: Throwable => log.info("Cannot perform the optimization query because of the following exception:", e) } } } private def createEntityConstraint( entityType: String, entityIdentifier: String, constraintsOptimizationType: ConstraintsOptimizationType.Value, keys: Map[String, String] ): Unit = { val constraintType = if 
(constraintsOptimizationType == ConstraintsOptimizationType.UNIQUE) { "UNIQUE" } else { s"$entityType KEY" } val dashSeparatedProps = keys.values.mkString("-") val constraintName = s"spark_${entityType}_${constraintType.replace(s"$entityType ", "")}-CONSTRAINT_${entityIdentifier}_$dashSeparatedProps".quote() val props = keys.values.map(_.quote()).map("e." + _).mkString(", ") val asciiRepresentation: String = createCypherPattern(entityType, entityIdentifier) session.writeTransaction( tx => { tx.run( s"CREATE CONSTRAINT $constraintName IF NOT EXISTS FOR $asciiRepresentation REQUIRE ($props) IS $constraintType" ) }, sessionTransactionConfig ) } private def createCypherPattern(entityType: String, entityIdentifier: String) = { val asciiRepresentation = entityType match { case "NODE" => s"(e:${entityIdentifier.quote()})" case "RELATIONSHIP" => s"()-[e:${entityIdentifier.quote()}]->()" case _ => throw new IllegalArgumentException(s"$entityType not supported") } asciiRepresentation } private def createEntityTypeConstraint( entityType: String, entityIdentifier: String, properties: Map[String, String], struct: StructType, constraints: Set[SchemaConstraintsOptimizationType.Value] ): Unit = { val asciiRepresentation: String = createCypherPattern(entityType, entityIdentifier) session.writeTransaction( tx => { properties .filter(t => struct.exists(f => f.name == t._1)) .map(t => { val field = struct.find(f => f.name == t._1).get (t._2, sparkToCypherTypeConverter.convert(field.dataType), field.nullable) }) .foreach(t => { val prop = t._1.quote() val cypherType = t._2 val isNullable = t._3 if (constraints.contains(SchemaConstraintsOptimizationType.TYPE)) { val typeConstraintName = s"spark_$entityType-TYPE-CONSTRAINT-$entityIdentifier-$prop".quote() tx.run( s"CREATE CONSTRAINT $typeConstraintName IF NOT EXISTS FOR $asciiRepresentation REQUIRE e.$prop IS :: $cypherType" ).consume() } if (constraints.contains(SchemaConstraintsOptimizationType.EXISTS)) { if (!isNullable) { val 
notNullConstraintName = s"spark_$entityType-NOT_NULL-CONSTRAINT-$entityIdentifier-$prop".quote() tx.run( s"CREATE CONSTRAINT $notNullConstraintName IF NOT EXISTS FOR $asciiRepresentation REQUIRE e.$prop IS NOT NULL" ).consume() } } }) }, sessionTransactionConfig ) } private def createOptimizationsForNode(struct: StructType): Unit = { val schemaMetadata = options.schemaMetadata.optimization if ( schemaMetadata.nodeConstraint != ConstraintsOptimizationType.NONE || schemaMetadata.schemaConstraints != Set(SchemaConstraintsOptimizationType.NONE) ) { if (schemaMetadata.nodeConstraint != ConstraintsOptimizationType.NONE) { createEntityConstraint( "NODE", options.nodeMetadata.labels.head, schemaMetadata.nodeConstraint, options.nodeMetadata.nodeKeys ) } if (schemaMetadata.schemaConstraints.nonEmpty) { val propsFromStruct: Map[String, String] = struct .map(f => (f.name, f.name)) .toMap val propsFromMeta: Map[String, String] = options.nodeMetadata.nodeKeys ++ options.nodeMetadata.properties createEntityTypeConstraint( "NODE", options.nodeMetadata.labels.head, propsFromStruct ++ propsFromMeta, struct, schemaMetadata.schemaConstraints ) } } else { // TODO old behaviour, remove it in the future options.schemaMetadata.optimizationType match { case OptimizationType.INDEX | OptimizationType.NODE_CONSTRAINTS => { createIndexOrConstraint( options.schemaMetadata.optimizationType, options.nodeMetadata.labels.head, options.nodeMetadata.nodeKeys.values.toSeq ) } case _ => // do nothing } } } private def createOptimizationsForRelationship(struct: StructType): Unit = { val schemaMetadata = options.schemaMetadata.optimization if ( schemaMetadata.nodeConstraint != ConstraintsOptimizationType.NONE || schemaMetadata.relConstraint != ConstraintsOptimizationType.NONE || schemaMetadata.schemaConstraints != Set(SchemaConstraintsOptimizationType.NONE) ) { if (schemaMetadata.nodeConstraint != ConstraintsOptimizationType.NONE) { createEntityConstraint( "NODE", 
options.relationshipMetadata.source.labels.head, schemaMetadata.nodeConstraint, options.relationshipMetadata.source.nodeKeys ) createEntityConstraint( "NODE", options.relationshipMetadata.target.labels.head, schemaMetadata.nodeConstraint, options.relationshipMetadata.target.nodeKeys ) } if (schemaMetadata.relConstraint != ConstraintsOptimizationType.NONE) { createEntityConstraint( "RELATIONSHIP", options.relationshipMetadata.relationshipType, schemaMetadata.relConstraint, options.relationshipMetadata.relationshipKeys ) } if (schemaMetadata.schemaConstraints.nonEmpty) { val sourceNodeProps: Map[String, String] = options.relationshipMetadata.source.nodeKeys ++ options.relationshipMetadata.source.properties val targetNodeProps: Map[String, String] = options.relationshipMetadata.target.nodeKeys ++ options.relationshipMetadata.target.properties val allNodeProps: Map[String, String] = sourceNodeProps ++ targetNodeProps val relStruct: StructType = StructType(struct.filterNot(f => allNodeProps.contains(f.name))) val propsFromRelStruct: Map[String, String] = relStruct .map(f => (f.name, f.name)) .toMap val propsFromMeta: Map[String, String] = options.relationshipMetadata.relationshipKeys ++ options.relationshipMetadata.properties.getOrElse(Map.empty) createEntityTypeConstraint( "RELATIONSHIP", options.relationshipMetadata.relationshipType, propsFromRelStruct ++ propsFromMeta, struct, schemaMetadata.schemaConstraints ) createEntityTypeConstraint( "NODE", options.relationshipMetadata.source.labels.head, sourceNodeProps, struct, schemaMetadata.schemaConstraints ) createEntityTypeConstraint( "NODE", options.relationshipMetadata.target.labels.head, targetNodeProps, struct, schemaMetadata.schemaConstraints ) } } else { // TODO old behaviour, remove it in the future options.schemaMetadata.optimizationType match { case OptimizationType.INDEX | OptimizationType.NODE_CONSTRAINTS => { createIndexOrConstraint( options.schemaMetadata.optimizationType, 
options.relationshipMetadata.source.labels.head, options.relationshipMetadata.source.nodeKeys.values.toSeq ) createIndexOrConstraint( options.schemaMetadata.optimizationType, options.relationshipMetadata.target.labels.head, options.relationshipMetadata.target.nodeKeys.values.toSeq ) } case _ => // do nothing } } } def createOptimizations(struct: StructType): Unit = { Validations.validate(ValidateSchemaOptions(options, struct)) options.query.queryType match { case QueryType.LABELS => createOptimizationsForNode(struct) case QueryType.RELATIONSHIP => createOptimizationsForRelationship(struct) case _ => // do nothing } } def execute(queries: Seq[String]): util.List[util.Map[String, AnyRef]] = { val queryMap = queries .map(query => { (session.run(s"EXPLAIN $query", sessionTransactionConfig).consume().queryType(), query) }) .groupBy(_._1) .mapValues(_.map(_._2)) val schemaQueries = queryMap.getOrElse(org.neo4j.driver.summary.QueryType.SCHEMA_WRITE, Seq.empty[String]) schemaQueries.foreach(session.run(_, sessionTransactionConfig)) val others = queryMap .filterKeys(key => key != org.neo4j.driver.summary.QueryType.SCHEMA_WRITE) .values .flatten .toSeq if (others.isEmpty) { Collections.emptyList() } else { session .writeTransaction( new TransactionWork[util.List[java.util.Map[String, AnyRef]]] { override def execute(transaction: Transaction): util.List[util.Map[String, AnyRef]] = { others.size match { case 1 => transaction.run(others.head).list() .asScala .map(_.asMap()) .asJava case _ => { others .slice(0, queries.size - 1) .foreach(transaction.run) val result = transaction.run(others.last).list() .asScala .map(_.asMap()) .asJava result } } } }, sessionTransactionConfig ) } } def lastOffset(): Option[Long] = options.query.queryType match { case QueryType.LABELS => lastOffsetForNode() case QueryType.RELATIONSHIP => lastOffsetForRelationship() case QueryType.QUERY => lastOffsetForQuery() } private def lastOffsetForNode(): Option[Long] = { val label = 
options.nodeMetadata.labels.head session.run( s"""MATCH (n:$label) |RETURN max(n.${options.streamingOptions.propertyName}) AS ${options.streamingOptions.propertyName}""".stripMargin, sessionTransactionConfig ) .single() .get(options.streamingOptions.propertyName) .asOptionalLong() } private def lastOffsetForRelationship(): Option[Long] = { val sourceLabel = options.relationshipMetadata.source.labels.head.quote() val targetLabel = options.relationshipMetadata.target.labels.head.quote() val relType = options.relationshipMetadata.relationshipType.quote() session.run( s"""MATCH (s:$sourceLabel)-[r:$relType]->(t:$targetLabel) |RETURN max(r.${options.streamingOptions.propertyName}) AS ${options.streamingOptions.propertyName}""".stripMargin, sessionTransactionConfig ) .single() .get(options.streamingOptions.propertyName) .asOptionalLong() } private def lastOffsetForQuery(): Option[Long] = { session.run(options.streamingOptions.queryOffset, sessionTransactionConfig) .single() .get(0) .asOptionalLong() } private def logResolutionChange(message: String, e: ClientException): Unit = { log.warn(message) if (!e.code().equals("Neo.ClientError.Procedure.ProcedureNotFound")) { log.warn(s"For the following exception", e) } } override def close(): Unit = { Neo4jUtil.closeSafely(session, log) } } object SchemaService { val POINT_TYPE_2D = "point-2d" val POINT_TYPE_3D = "point-3d" val TIME_TYPE_OFFSET = "offset-time" val TIME_TYPE_LOCAL = "local-time" val DURATION_TYPE = "duration" def normalizedClassName(value: AnyRef, options: Neo4jOptions): String = value match { case binary: Array[Byte] => if (options.legacyTypeConversionEnabled) value.getClass.getSimpleName else "ByteArray" case list: java.util.List[_] => "Array" case map: java.util.Map[String, _] => "Map" case null => "String" case _ => value.getClass.getSimpleName } // from nodes and relationships we cannot have maps as properties and elements in lists are the same type // special treatment for ByteArray required (pattern 
// matching on Array != List)
  /**
   * Like normalizedClassName but for property values coming from graph entities,
   * where maps cannot occur and list elements share a single type.
   * NOTE(review): `list.get(0)` assumes a non-empty list with a non-null first
   * element — an empty list property would throw; confirm upstream guarantees this.
   */
  def normalizedClassNameFromGraphEntity(value: AnyRef, options: Neo4jOptions): String = value match {
    case binary: Array[Byte] =>
      if (options.legacyTypeConversionEnabled) value.getClass.getSimpleName else "ByteArray"
    case list: java.util.List[_] => s"${list.get(0).getClass.getSimpleName}Array"
    case null => "String"
    case _ => value.getClass.getSimpleName
  }
}

================================================ FILE: common/src/main/scala/org/neo4j/spark/streaming/BaseStreamingPartitionReader.scala ================================================
/*
 * Copyright (c) "Neo4j"
 * Neo4j Sweden AB [https://neo4j.com]
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
*/
package org.neo4j.spark.streaming

import org.apache.spark.sql.connector.expressions.aggregate.AggregateFunc
import org.apache.spark.sql.sources.Filter
import org.apache.spark.sql.sources.GreaterThan
import org.apache.spark.sql.sources.LessThanOrEqual
import org.apache.spark.sql.types.StructType
import org.neo4j.caniuse.Neo4j
import org.neo4j.cypherdsl.core.Cypher
import org.neo4j.spark.reader.BasePartitionReader
import org.neo4j.spark.service.Neo4jQueryStrategy
import org.neo4j.spark.service.PartitionPagination
import org.neo4j.spark.streaming.BaseStreamingPartitionReader.offsetUsagePatterns
import org.neo4j.spark.util.Neo4jImplicits._
import org.neo4j.spark.util.Neo4jOptions
import org.neo4j.spark.util.Neo4jUtil
import org.neo4j.spark.util.QueryType._
import org.neo4j.spark.util.StreamingFrom

import java.util
import java.util.function.Predicate
import java.util.regex.Pattern
import scala.collection.JavaConverters.mapAsJavaMapConverter

/**
 * Partition reader for structured-streaming reads: extends the batch reader by
 * binding the micro-batch offset range (`$stream.from`/`$stream.to`) as query
 * parameters, and by rewriting user-provided queries to filter on that range.
 */
class BaseStreamingPartitionReader(
  private val neo4j: Neo4j,
  private val options: Neo4jOptions,
  private val filters: Array[Filter],
  private val schema: StructType,
  private val jobId: String,
  private val partitionSkipLimit: PartitionPagination,
  private val scriptResult: java.util.List[java.util.Map[String, AnyRef]],
  private val requiredColumns: StructType,
  private val aggregateColumns: Array[AggregateFunc]
) extends BasePartitionReader(
      neo4j,
      options,
      filters,
      schema,
      jobId,
      partitionSkipLimit,
      scriptResult,
      requiredColumns,
      aggregateColumns
    ) {

  private val streamingPropertyName = Neo4jUtil.getStreamingPropertyName(options)

  // The micro-batch range is pushed down as two filters on the streaming property:
  // a GreaterThan (exclusive lower bound) and a LessThanOrEqual (inclusive upper bound).
  // Note: `getAttribute` returns an Option, so `contains` here is an exact-equality check.
  private val streamStart =
    filters.find(f => f.getAttribute.contains(streamingPropertyName) && f.isInstanceOf[GreaterThan])

  private val streamEnd =
    filters.find(f => f.getAttribute.contains(streamingPropertyName) && f.isInstanceOf[LessThanOrEqual])

  logInfo(s"Creating Streaming Partition reader $name")

  // Query parameters extended with the `$stream` map; computed lazily, once.
  private lazy val values = {
    val map = new util.HashMap[String, Any](super.queryParameters)
    // Missing lower bound means "read everything from the beginning".
    val start: Long = streamStart
      .flatMap(f => f.getValue)
      .getOrElse(StreamingFrom.ALL.value())
      .asInstanceOf[Long]
    // NOTE(review): `.get` assumes the planner always pushes an upper-bound filter
    // for the streaming property — a missing one would throw; confirm this invariant.
    val end: Long = streamEnd
      .flatMap(f => f.getValue)
      .get
      .asInstanceOf[Long]
    // "offset" is kept for backward compatibility with queries using `$stream.offset`.
    map.put(Neo4jQueryStrategy.VARIABLE_STREAM, Map("offset" -> start, "from" -> start, "to" -> end).asJava)
    map
  }

  override def close(): Unit = {
    logInfo(s"Closing Partition reader $name ${if (hasError()) "with error " else ""}")
    super.close()
  }

  /**
   * For user QUERYs, wraps the original query with a WHERE clause restricting the
   * streaming property to the ($stream.from, $stream.to] range; LABELS/RELATIONSHIP
   * queries are untouched because the pushed-down Spark filters already cover it.
   */
  override protected def query(): String = {
    options.query.queryType match {
      case QUERY =>
        val originalQuery = super.query()
        if (offsetUsagePatterns.exists(_.test(originalQuery))) {
          logWarning(
            "Usage of '$stream.offset' is deprecated in favor of '$stream.from' and '$stream.to' parameters which " +
              "describes the range of changes the micro batch refers to. Please update your queries accordingly."
          )
        }
        val property = Cypher.name(streamingPropertyName)
        val stream = Cypher.parameter("stream")
        // rewrite query for adding $stream.from and $stream.to filters
        Cypher.callRawCypher(originalQuery)
          .`with`(Cypher.asterisk())
          .where(
            property.gt(stream.property("from")).and(property.lte(stream.property("to")))
          )
          .returning(Cypher.asterisk())
          .build()
          .getCypher
      // we don't need to rewrite the queries for LABELS and RELATIONSHIPS because spark filters already cover our
      // criteria which are added to the query text in Neo4jQueryService
      case LABELS => super.query()
      case RELATIONSHIP => super.query()
      case GDS =>
        throw new UnsupportedOperationException("GDS strategy is not supported in structured streaming use cases.")
    }
  }

  override protected def queryParameters: util.Map[String, Any] = values
}

object BaseStreamingPartitionReader {

  // All quoting variants under which `$stream.offset` may appear in a user query.
  private val offsetUsagePatterns: Seq[Predicate[String]] = Seq(
    Pattern.compile("\\$stream\\.offset").asPredicate(),
    Pattern.compile("\\$`stream`\\.offset").asPredicate(),
    Pattern.compile("\\$stream\\.`offset`").asPredicate(),
    Pattern.compile("\\$`stream`\\.`offset`").asPredicate()
  )
}

================================================ FILE:
common/src/main/scala/org/neo4j/spark/util/DriverCache.scala ================================================
/*
 * Copyright (c) "Neo4j"
 * Neo4j Sweden AB [https://neo4j.com]
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package org.neo4j.spark.util

import org.neo4j.driver.Driver
import org.neo4j.spark.util.DriverCache.cache

import java.util.concurrent.ConcurrentHashMap
import java.util.concurrent.atomic.AtomicInteger

object DriverCache {

  // Process-wide registry of shared drivers, keyed by connection options and paired
  // with a reference count of the handles currently using each driver.
  private val cache: ConcurrentHashMap[Neo4jDriverOptions, (Driver, AtomicInteger)] =
    new ConcurrentHashMap[Neo4jDriverOptions, (Driver, AtomicInteger)]
}

/**
 * Serializable, reference-counted handle to a [[Driver]] shared by all instances
 * created with the same [[Neo4jDriverOptions]]. The driver is created lazily on
 * first use and closed when the last handle is closed.
 */
class DriverCache(private val options: Neo4jDriverOptions) extends Serializable with AutoCloseable {

  /** Returns the shared driver for `options`, creating it on first use, and bumps the refcount. */
  def getOrCreate(): Driver = {
    val (driver, counter) =
      cache.computeIfAbsent(options, (t: Neo4jDriverOptions) => (t.createDriver(), new AtomicInteger(0)))
    counter.incrementAndGet()
    // NOTE(review): the increment happens outside computeIfAbsent, so a concurrent
    // close() may observe a zero count between the two steps — confirm callers
    // serialize driver lifecycle per options.
    driver
  }

  /** Drops one reference; closes and evicts the driver when the last reference is released. */
  def close(): Unit = {
    // FIX: guard against a missing entry (close() without a matching getOrCreate(),
    // or a double close): cache.get would return null and destructuring it threw a
    // MatchError instead of being a no-op.
    Option(cache.get(options)).foreach { case (driver, counter) =>
      if (counter.decrementAndGet() == 0) {
        cache.remove(options)
        Neo4jUtil.closeSafely(driver)
      }
    }
  }
}

================================================ FILE: common/src/main/scala/org/neo4j/spark/util/Neo4jImplicits.scala ================================================
/*
 * Copyright (c) "Neo4j"
 * Neo4j Sweden AB [https://neo4j.com]
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
* You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ package org.neo4j.spark.util import com.fasterxml.jackson.core.JsonParseException import com.fasterxml.jackson.core.JsonParser import com.fasterxml.jackson.databind.ObjectMapper import org.apache.spark.sql.connector.expressions.Expression import org.apache.spark.sql.connector.expressions.Literal import org.apache.spark.sql.connector.expressions.aggregate.Aggregation import org.apache.spark.sql.connector.expressions.filter import org.apache.spark.sql.connector.expressions.filter.Predicate import org.apache.spark.sql.sources.AlwaysFalse import org.apache.spark.sql.sources.AlwaysTrue import org.apache.spark.sql.sources.And import org.apache.spark.sql.sources.EqualNullSafe import org.apache.spark.sql.sources.EqualTo import org.apache.spark.sql.sources.Filter import org.apache.spark.sql.sources.GreaterThan import org.apache.spark.sql.sources.GreaterThanOrEqual import org.apache.spark.sql.sources.In import org.apache.spark.sql.sources.IsNotNull import org.apache.spark.sql.sources.IsNull import org.apache.spark.sql.sources.LessThan import org.apache.spark.sql.sources.LessThanOrEqual import org.apache.spark.sql.sources.Not import org.apache.spark.sql.sources.Or import org.apache.spark.sql.sources.StringContains import org.apache.spark.sql.sources.StringEndsWith import org.apache.spark.sql.sources.StringStartsWith import org.apache.spark.sql.types.DataTypes import org.apache.spark.sql.types.MapType import org.apache.spark.sql.types.StructField import org.apache.spark.sql.types.StructType import org.neo4j.driver.Value import org.neo4j.driver.types.Entity import 
org.neo4j.driver.types.Node
import org.neo4j.driver.types.Relationship
import org.neo4j.spark.converter.CypherToSparkTypeConverter
import org.neo4j.spark.converter.SparkToNeo4jDataConverter
import org.neo4j.spark.service.SchemaService

import scala.collection.JavaConverters._
import javax.lang.model.SourceVersion

object Neo4jImplicits {

  /** String helpers for building Cypher: backtick quoting and parameter-name mangling. */
  implicit class CypherImplicits(str: String) {

    // Java-identifier check is used as a proxy for "safe Cypher identifier";
    // strings starting with '$' are parameters and never valid identifiers.
    private def isValidCypherIdentifier() = SourceVersion.isIdentifier(str) && !str.trim.startsWith("$")

    // Wraps in backticks only when needed and not already quoted.
    def quote(): String = if (!isValidCypherIdentifier() && !str.isQuoted()) s"`$str`" else str

    // NOTE(review): strips ALL backticks, not just the delimiters — assumes
    // identifiers never contain embedded backticks.
    def unquote(): String = str.replaceAll("`", "");

    def isQuoted(): Boolean = str.startsWith("`");

    // Drops the leading alias segment ("alias.prop" -> "prop"); no-op when there is no dot.
    def removeAlias(): String = {
      val splatString = str.unquote().split('.')
      if (splatString.size > 1) {
        splatString.tail.mkString(".")
      } else {
        str
      }
    }

    /**
     * df: we need this to handle scenarios like `WHERE age > 19 and age < 22`,
     * so we can't basically add a parameter named \$age.
     * So we base64 encode the value to ensure a unique parameter name
     */
    def toParameterName(value: Any): String = {
      val attributeValue =
        if (value == null) {
          "NULL"
        } else {
          value.toString
        }
      val base64ed = java.util.Base64.getEncoder.encodeToString(attributeValue.getBytes())
      s"${base64ed}_${str.unquote()}".quote()
    }
  }

  /** Converts driver entities (nodes/relationships) to Spark schemas and plain maps. */
  implicit class EntityImplicits(entity: Entity) {

    /**
     * Builds a StructType from the entity's properties, prepending the connector's
     * internal fields (ids, labels/type) depending on the entity kind.
     */
    def toStruct(options: Neo4jOptions): StructType = {
      val fields = entity.asMap().asScala
        .groupBy(_._1)
        .map(t => {
          val value = t._2.head._2
          val cypherType = SchemaService.normalizedClassNameFromGraphEntity(value, options)
          StructField(t._1, CypherToSparkTypeConverter(options).convert(cypherType))
        })
      val entityFields = entity match {
        case _: Node => {
          Seq(
            StructField(Neo4jUtil.INTERNAL_ID_FIELD, DataTypes.LongType, nullable = false),
            StructField(
              Neo4jUtil.INTERNAL_LABELS_FIELD,
              DataTypes.createArrayType(DataTypes.StringType),
              nullable = true
            )
          )
        }
        case _: Relationship => {
          Seq(
            StructField(Neo4jUtil.INTERNAL_REL_ID_FIELD, DataTypes.LongType, nullable = false),
StructField(Neo4jUtil.INTERNAL_REL_TYPE_FIELD, DataTypes.StringType, nullable = false),
            StructField(Neo4jUtil.INTERNAL_REL_SOURCE_ID_FIELD, DataTypes.LongType, nullable = false),
            StructField(Neo4jUtil.INTERNAL_REL_TARGET_ID_FIELD, DataTypes.LongType, nullable = false)
          )
        }
      }
      StructType(entityFields ++ fields)
    }

    // Entity properties plus the internal id/label/type fields, as a Java map.
    def toMap: java.util.Map[String, Any] = {
      val entityMap = entity.asMap().asScala
      val entityFields = entity match {
        case node: Node => {
          Map(Neo4jUtil.INTERNAL_ID_FIELD -> node.id(), Neo4jUtil.INTERNAL_LABELS_FIELD -> node.labels())
        }
        case relationship: Relationship => {
          Map(
            Neo4jUtil.INTERNAL_REL_ID_FIELD -> relationship.id(),
            Neo4jUtil.INTERNAL_REL_TYPE_FIELD -> relationship.`type`(),
            Neo4jUtil.INTERNAL_REL_SOURCE_ID_FIELD -> relationship.startNodeId(),
            Neo4jUtil.INTERNAL_REL_TARGET_ID_FIELD -> relationship.endNodeId()
          )
        }
      }
      (entityFields ++ entityMap).asJava
    }
  }

  /** Translates Spark V2 [[Predicate]]s back into V1 [[Filter]]s the connector understands. */
  implicit class PredicateImplicit(predicate: Predicate) {

    /**
     * Maps the predicate by its symbolic name; returns None when a literal operand
     * is required but missing (so the predicate cannot be pushed down).
     */
    def toFilter(options: Neo4jOptions): Option[Filter] = {
      predicate.name() match {
        case "IS_NULL" => Some(IsNull(predicate.rawAttributeName()))
        case "IS_NOT_NULL" => Some(IsNotNull(predicate.rawAttributeName()))
        case "STARTS_WITH" =>
          predicate.rawLiteralValue(options).map(lit => StringStartsWith(predicate.rawAttributeName(), lit.asString()))
        case "ENDS_WITH" =>
          predicate.rawLiteralValue(options).map(lit => StringEndsWith(predicate.rawAttributeName(), lit.asString()))
        case "CONTAINS" =>
          predicate.rawLiteralValue(options).map(lit => StringContains(predicate.rawAttributeName(), lit.asString()))
        case "IN" => Some(In(predicate.rawAttributeName(), predicate.rawLiteralValues(options)))
        case "=" =>
          predicate.rawLiteralValue(options).map(lit => EqualTo(predicate.rawAttributeName(), lit.asObject()))
        case "<>" =>
          predicate.rawLiteralValue(options).map(lit => Not(EqualTo(predicate.rawAttributeName(), lit.asObject())))
        case "<=>" =>
          predicate.rawLiteralValue(options).map(lit => EqualNullSafe(predicate.rawAttributeName(), lit.asObject()))
        case "<" =>
predicate.rawLiteralValue(options).map(lit => LessThan(predicate.rawAttributeName(), lit.asObject()))
        case "<=" =>
          predicate.rawLiteralValue(options).map(lit => LessThanOrEqual(predicate.rawAttributeName(), lit.asObject()))
        case ">" =>
          predicate.rawLiteralValue(options).map(lit => GreaterThan(predicate.rawAttributeName(), lit.asObject()))
        case ">=" =>
          predicate.rawLiteralValue(options).map(lit =>
            GreaterThanOrEqual(predicate.rawAttributeName(), lit.asObject())
          )
        // AND/OR only translate when BOTH sides translate; a single untranslatable
        // side drops the whole conjunction/disjunction from pushdown.
        case "AND" =>
          val andPredicate = predicate.asInstanceOf[filter.And]
          (andPredicate.left().toFilter(options), andPredicate.right().toFilter(options)) match {
            case (_, None) => None
            case (None, _) => None
            case (Some(left), Some(right)) => Some(And(left, right))
          }
        case "OR" =>
          val andPredicate = predicate.asInstanceOf[filter.Or]
          (andPredicate.left().toFilter(options), andPredicate.right().toFilter(options)) match {
            case (_, None) => None
            case (None, _) => None
            case (Some(left), Some(right)) => Some(Or(left, right))
          }
        case "NOT" =>
          val notPredicate = predicate.asInstanceOf[filter.Not]
          notPredicate.child().toFilter(options).map(Not)
        case "ALWAYS_TRUE" => Some(AlwaysTrue)
        case "ALWAYS_FALSE" => Some(AlwaysFalse)
        // NOTE(review): no default case — an unrecognized predicate name throws a
        // MatchError; confirm Spark can never hand over other names here.
      }
    }

    // Dotted attribute path of the predicate's first referenced column.
    def rawAttributeName(): String = {
      predicate.references().head.fieldNames().mkString(".")
    }

    // First literal child converted to a Neo4j Value, if any.
    def rawLiteralValue(options: Neo4jOptions): Option[Value] = {
      predicate.children()
        .filter(_.isInstanceOf[Literal[_]])
        .map(_.asInstanceOf[Literal[_]])
        .headOption
        .map(literal => SparkToNeo4jDataConverter(options).convert(literal.value(), literal.dataType()))
    }

    // All literal children converted to plain objects (used for IN lists).
    def rawLiteralValues(options: Neo4jOptions): Array[Any] = {
      predicate.children()
        .filter(_.isInstanceOf[Literal[_]])
        .map(_.asInstanceOf[Literal[_]])
        .map(v => SparkToNeo4jDataConverter(options).convert(v.value(), v.dataType()).asObject())
    }
  }

  /** Helpers over Spark V1 [[Filter]]s: flattening, attribute/value extraction. */
  implicit class FilterImplicit(filter: Filter) {

    // Recursively unnests And/Or trees into a flat array of leaf filters.
    def flattenFilters: Array[Filter] = {
      filter match {
        case or: Or => Array(or.left.flattenFilters, or.right.flattenFilters).flatten
        case and: And =>
Array(and.left.flattenFilters, and.right.flattenFilters).flatten
        case f: Filter => Array(f)
      }
    }

    // Attribute name of the filter, unwrapping Not; None for composite/unknown filters.
    def getAttribute: Option[String] = Option(filter match {
      case eqns: EqualNullSafe => eqns.attribute
      case eq: EqualTo => eq.attribute
      case gt: GreaterThan => gt.attribute
      case gte: GreaterThanOrEqual => gte.attribute
      case lt: LessThan => lt.attribute
      case lte: LessThanOrEqual => lte.attribute
      case in: In => in.attribute
      case notNull: IsNotNull => notNull.attribute
      case isNull: IsNull => isNull.attribute
      case startWith: StringStartsWith => startWith.attribute
      case endsWith: StringEndsWith => endsWith.attribute
      case contains: StringContains => contains.attribute
      case not: Not => not.child.getAttribute.orNull
      case _ => null
    })

    // Comparison value of the filter (an array for In), unwrapping Not; None otherwise.
    def getValue: Option[Any] = Option(filter match {
      case eqns: EqualNullSafe => eqns.value
      case eq: EqualTo => eq.value
      case gt: GreaterThan => gt.value
      case gte: GreaterThanOrEqual => gte.value
      case lt: LessThan => lt.value
      case lte: LessThanOrEqual => lte.value
      case in: In => in.values
      case startWith: StringStartsWith => startWith.value
      case endsWith: StringEndsWith => endsWith.value
      case contains: StringContains => contains.value
      case not: Not => not.child.getValue.orNull
      case _ => null
    })

    // True when the attribute is namespaced with the given entity prefix (e.g. "source.").
    def isAttribute(entityType: String): Boolean = {
      getAttribute.exists(_.contains(s"$entityType."))
    }

    // Attribute with its first (entity) segment removed.
    def getAttributeWithoutEntityName: Option[String] =
      filter.getAttribute.map(_.unquote().split('.').tail.mkString("."))

    /**
     * df: we are not handling AND/OR because they are not actually filters
     * and have a different internal structure. Before calling this function on the filters
     * it's highly suggested FilterImplicit::flattenFilter() which returns a collection
     * of filters, including the one contained in the ANDs/ORs objects.
*/
    def getAttributeAndValue: Seq[Any] = {
      filter match {
        case f: EqualNullSafe => Seq(f.attribute.toParameterName(f.value), f.value)
        case f: EqualTo => Seq(f.attribute.toParameterName(f.value), f.value)
        case f: GreaterThan => Seq(f.attribute.toParameterName(f.value), f.value)
        case f: GreaterThanOrEqual => Seq(f.attribute.toParameterName(f.value), f.value)
        case f: LessThan => Seq(f.attribute.toParameterName(f.value), f.value)
        case f: LessThanOrEqual => Seq(f.attribute.toParameterName(f.value), f.value)
        case f: In => Seq(f.attribute.toParameterName(f.values), f.values)
        case f: StringStartsWith => Seq(f.attribute.toParameterName(f.value), f.value)
        case f: StringEndsWith => Seq(f.attribute.toParameterName(f.value), f.value)
        case f: StringContains => Seq(f.attribute.toParameterName(f.value), f.value)
        case f: Not => f.child.getAttributeAndValue
        case _ => Seq()
      }
    }
  }

  /** Lookup helpers over Spark [[StructType]] schemas. */
  implicit class StructTypeImplicit(structType: StructType) {

    // Matches `field` against a map/struct column name, honoring backtick quoting
    // and taking only the segment before the first unquoted dot.
    private def isValidMapOrStructField(field: String, structFieldName: String) = {
      val value: String = """(`.*`)|([^\.]*)""".r.findFirstIn(field).getOrElse("")
      structFieldName == value.unquote() || structFieldName == value
    }

    /** Returns the field with the given name, or None when the schema has no such field. */
    def getByName(name: String): Option[StructField] = {
      // FIX: StructType.fieldIndex throws IllegalArgumentException for unknown names
      // and never returns -1, so the original `if (index > -1) ... else None` could
      // never produce None — a missing field escaped as an exception instead.
      // Look the field up without throwing.
      structType.fields.find(_.name == name)
    }

    /** Index of the field, or -1 when absent. */
    def getFieldIndex(fieldName: String): Long = structType.fields.map(_.name).indexOf(fieldName)

    /** Subset of `fields` that have no matching column in the schema. */
    def getMissingFields(fields: Set[String]): Set[String] = fields
      .map(field => {
        val maybeField = structType
          .find(structField => {
            structField.dataType match {
              case _: MapType => isValidMapOrStructField(field, structField.name)
              case _: StructType => isValidMapOrStructField(field, structField.name)
              case _ => structField.name == field.unquote() || structField.name == field
            }
          })
        field -> maybeField.isDefined
      })
      .filterNot(e => e._2)
      .map(e => e._1)
  }

  /** Reflection-based access to an Aggregation's group-by columns (API differs across Spark versions). */
  implicit class AggregationImplicit(aggregation: Aggregation) {
    def groupByCols(): Array[Expression] = ReflectionUtils.groupByCols(aggregation)
  }

  implicit
class MapImplicit[K, V](map: Map[String, V]) {
    // NOTE(review): type parameter K is declared but unused — the key type is fixed
    // to String; confirm whether K can be dropped without breaking binary compat.

    // Recursively flattens nested (Scala or Java) maps into dot-separated keys.
    private def innerFlattenMap(map: Map[String, _], prefix: String): Seq[(String, AnyRef)] = map
      .toSeq
      .flatMap(t => {
        val key: String = if (prefix != "") s"$prefix.${t._1}" else t._1
        t._2 match {
          case nestedMap: Map[String, _] => innerFlattenMap(nestedMap, key)
          case nestedMap: java.util.Map[String, _] => innerFlattenMap(nestedMap.asScala.toMap, key)
          case _ => Seq((key, t._2.asInstanceOf[AnyRef]))
        }
      })
      .toList

    /**
     * Flattens the map to dot-separated keys; on duplicate flattened keys either
     * keeps the last value or (when groupDuplicateKeys) collects them into a Java list.
     */
    def flattenMap(prefix: String = "", groupDuplicateKeys: Boolean = false): Map[String, AnyRef] =
      innerFlattenMap(map, prefix)
        .groupBy(_._1)
        .mapValues(seq => if (groupDuplicateKeys && seq.size > 1) seq.map(_._2).asJava else seq.last._2)
        .toMap

    /** Flattened dot-separated key paths of the (possibly nested) map. */
    def flattenKeys(prefix: String = ""): Seq[String] = map
      .flatMap(t => {
        val key: String = if (prefix != "") s"$prefix.${t._1}" else t._1
        t._2 match {
          case nestedMap: Map[String, _] => nestedMap.flattenKeys(key)
          case nestedMap: java.util.Map[String, _] => nestedMap.asScala.toMap.flattenKeys(key)
          case _ => Seq(key)
        }
      })
      .toList
  }

  /** Turns a flat "a.b.c -> json" string map into a nested Java map with parsed values. */
  implicit class StringMapImplicits(map: Map[String, String]) {

    private val propertyMapper = new ObjectMapper()
    propertyMapper.configure(JsonParser.Feature.ALLOW_SINGLE_QUOTES, true)

    // Splits keys on '.', nesting one Java map per segment; leaf values are parsed
    // as JSON, falling back to the raw string when parsing fails.
    private def nestingMap(data: Map[String, String]): java.util.Map[String, Any] = {
      val map = new java.util.HashMap[String, Any]();
      data.foreach(t => {
        val splitted = t._1.split("\\.")
        if (splitted.size == 1) {
          val value =
            try {
              propertyMapper.readValue[Any](t._2, classOf[Any])
            } catch {
              case _: JsonParseException => t._2
            }
          map.put(t._1, value)
        } else {
          if (map.containsKey(splitted.head)) {
            // NOTE(review): assumes the existing value at this segment is itself a map;
            // a "a -> x" / "a.b -> y" key clash would throw a ClassCastException — confirm.
            val value = map.get(splitted.head).asInstanceOf[java.util.Map[String, Any]]
            value.putAll(nestingMap(Map(splitted.drop(1).mkString(".") -> t._2)))
            map.put(splitted.head, value)
          } else {
            map.put(splitted.head, nestingMap(Map(splitted.drop(1).mkString(".") -> t._2)))
          }
        }
      })
      map
    }

    def toNestedJavaMap: java.util.Map[String, Any] = nestingMap(map)
  }

  implicit class ValueImplicits(value: Value) {
// Converts a driver Value to Some(long), or None when the value is Cypher null.
    def asOptionalLong(): Option[Long] = {
      if (value.isNull) {
        Option.empty
      } else {
        Option(value.asLong())
      }
    }
  }
}

================================================ FILE: common/src/main/scala/org/neo4j/spark/util/Neo4jOptions.scala ================================================
/*
 * Copyright (c) "Neo4j"
 * Neo4j Sweden AB [https://neo4j.com]
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package org.neo4j.spark.util

import org.apache.spark.internal.Logging
import org.apache.spark.sql.SaveMode
import org.apache.spark.sql.SparkSession
import org.jetbrains.annotations.TestOnly
import org.neo4j.connectors.authn.AuthenticationToken
import org.neo4j.connectors.authn.AuthenticationTokenSupplierFactory
import org.neo4j.connectors.common.driver.reauth.ReAuthDriverFactory
import org.neo4j.driver.Config.TrustStrategy
import org.neo4j.driver._
import org.neo4j.driver.exceptions.Neo4jException
import org.neo4j.driver.net.ServerAddress
import org.neo4j.spark.util.Neo4jImplicits.StringMapImplicits

import java.io.File
import java.net.URI
import java.time.Duration
import java.util
import java.util.Locale
import java.util.ServiceLoader
import java.util.UUID
import java.util.concurrent.TimeUnit
import java.util.function.Supplier
import scala.collection.JavaConverters._
import scala.language.implicitConversions

/**
 * Parses and exposes every connector option from the raw Spark option map,
 * eagerly validating during construction.
 */
class Neo4jOptions(private val options: java.util.Map[String, String]) extends Serializable with Logging {

  import Neo4jOptions._
  import QueryType._

  // Defensive copy of the raw options.
  def asMap() = new util.HashMap[String,
String](options)

  /** Returns the raw value of `parameter`, failing fast when it is missing or empty. */
  private def getRequiredParameter(parameter: String): String = {
    if (!options.containsKey(parameter) || options.get(parameter).isEmpty) {
      throw new IllegalArgumentException(s"Parameter '$parameter' is required")
    }
    options.get(parameter)
  }

  // Looks up `parameter`, trimmed, falling back to `defaultValue` when absent.
  private def getParameter(parameter: String, defaultValue: String = ""): String =
    getParameterOption(parameter).getOrElse(defaultValue)

  private def getParameterOption(parameter: String): Option[String] =
    Some(options.get(parameter))
      .flatMap(Option(_)) // to turn null into None
      .map(_.trim)

  // Collects every option under the "$AUTH.<authType>." namespace (prefix stripped)
  // layered on top of the connector defaults.
  private def getAuthenticationParameters: Map[String, String] = {
    val authType = getParameter(AUTH_TYPE, DEFAULT_AUTH_TYPE)
    val authNamespace = s"$AUTH.$authType"
    val providedParameters = options.asScala
      .filterKeys(_.startsWith(authNamespace))
      .map(t => (t._1.substring(authNamespace.length + 1), t._2))
      .toMap
    DEFAULT_AUTH_PARAMETERS ++ providedParameters
  }

  val saveMode: String = getParameter(SAVE_MODE, DEFAULT_SAVE_MODE.toString)

  // Pushdown toggles; each defaults to its connector-wide constant and is parsed as a boolean.
  val pushdownFiltersEnabled: Boolean =
    getParameter(PUSHDOWN_FILTERS_ENABLED, DEFAULT_PUSHDOWN_FILTERS_ENABLED.toString).toBoolean

  val pushdownColumnsEnabled: Boolean =
    getParameter(PUSHDOWN_COLUMNS_ENABLED, DEFAULT_PUSHDOWN_COLUMNS_ENABLED.toString).toBoolean

  val pushdownAggregateEnabled: Boolean =
    getParameter(PUSHDOWN_AGGREGATE_ENABLED, DEFAULT_PUSHDOWN_AGGREGATE_ENABLED.toString).toBoolean

  val pushdownLimitEnabled: Boolean =
    getParameter(PUSHDOWN_LIMIT_ENABLED, DEFAULT_PUSHDOWN_LIMIT_ENABLED.toString).toBoolean

  val pushdownTopNEnabled: Boolean =
    getParameter(PUSHDOWN_TOPN_ENABLED, DEFAULT_PUSHDOWN_TOPN_ENABLED.toString).toBoolean

  val schemaMetadata: Neo4jSchemaMetadata = initSchemaMetadata

  // Merges the deprecated schema-optimization option with the newer node/relationship
  // constraint options, warning when the deprecated one is in use.
  private def initSchemaMetadata = {
    val deprecatedSchemaOptimization = OptimizationType
      .withCaseInsensitiveName(getParameter(SCHEMA_OPTIMIZATION_TYPE, DEFAULT_OPTIMIZATION_TYPE.toString).toUpperCase)
    if (deprecatedSchemaOptimization != OptimizationType.NONE) {
      logWarning(
        s"""
           |Option `$SCHEMA_OPTIMIZATION_TYPE` is deprecated and will be removed in future implementations,
           |please move to one of the following depending on your use case:
           |- `$SCHEMA_OPTIMIZATION_NODE_KEY`
           |- `$SCHEMA_OPTIMIZATION_RELATIONSHIP_KEY`
           |""".stripMargin
      )
    }
    val nodeConstr: ConstraintsOptimizationType.Value = ConstraintsOptimizationType
      .withCaseInsensitiveName(getParameter(
        SCHEMA_OPTIMIZATION_NODE_KEY,
        DEFAULT_SCHEMA_OPTIMIZATION_NODE_KEY.toString
      ).trim)
    val relConstr: ConstraintsOptimizationType.Value = ConstraintsOptimizationType
      .withCaseInsensitiveName(getParameter(
        SCHEMA_OPTIMIZATION_RELATIONSHIP_KEY,
        DEFAULT_SCHEMA_OPTIMIZATION_RELATIONSHIP_KEY.toString
      ).trim)
    // Comma-separated list of schema-constraint optimization names.
    val schemaConstraints = getParameter(SCHEMA_OPTIMIZATION, DEFAULT_SCHEMA_OPTIMIZATION.toString)
      .split(",")
      .map(_.trim)
      .map(SchemaConstraintsOptimizationType.withCaseInsensitiveName)
      .toSet
    Neo4jSchemaMetadata(
      getParameter(SCHEMA_FLATTEN_LIMIT, DEFAULT_SCHEMA_FLATTEN_LIMIT.toString).toInt,
      SchemaStrategy.withCaseInsensitiveName(getParameter(
        SCHEMA_STRATEGY,
        DEFAULT_SCHEMA_STRATEGY.toString
      ).toUpperCase),
      deprecatedSchemaOptimization,
      Neo4jSchemaOptimizations(nodeConstr, relConstr, schemaConstraints),
      getParameter(SCHEMA_MAP_GROUP_DUPLICATE_KEYS, DEFAULT_MAP_GROUP_DUPLICATE_KEYS.toString).toBoolean
    )
  }

  val indexAwait = getParameter(INDEX_AWAIT_TIMEOUT_SEC, DEFAULT_INDEX_AWAIT_TIMEOUT_SEC.toString).toInt

  // Exactly one of the query/labels/relationship/gds options may be set; anything else is an error.
  val query: Neo4jQueryOptions = (
    getParameter(QUERY.toString.toLowerCase),
    getParameter(LABELS.toString.toLowerCase),
    getParameter(RELATIONSHIP.toString.toLowerCase()),
    getParameter(GDS.toString.toLowerCase())
  ) match {
    case (query, "", "", "") => Neo4jQueryOptions(QUERY, query)
    case ("", label, "", "") => {
      // A leading ':' (e.g. ":Person:Admin") is tolerated and stripped.
      val parsed = if (label.trim.startsWith(":")) label.substring(1) else label
      Neo4jQueryOptions(LABELS, parsed)
    }
    case ("", "", relationship, "") => Neo4jQueryOptions(RELATIONSHIP, relationship)
    case ("", "", "", gds) => Neo4jQueryOptions(GDS, gds)
    case _ => throw new IllegalArgumentException(
        s"You need to specify just one of these options: ${
            QueryType.values.toSeq.map(value => s"'${value.toString.toLowerCase()}'")
              .sorted.mkString(", ")
          }"
      )
  }

  // Driver-level connection settings; -1 timeouts mean "use the driver default".
  val connection: Neo4jDriverOptions = Neo4jDriverOptions(
    getRequiredParameter(URL),
    getParameter(AUTH_TYPE, DEFAULT_AUTH_TYPE),
    getAuthenticationParameters,
    getParameter(ENCRYPTION_ENABLED, DEFAULT_ENCRYPTION_ENABLED.toString).toBoolean,
    Option(getParameter(ENCRYPTION_TRUST_STRATEGY, null)),
    getParameter(ENCRYPTION_CA_CERTIFICATE_PATH, DEFAULT_EMPTY),
    getParameter(CONNECTION_MAX_LIFETIME_MSECS, DEFAULT_CONNECTION_MAX_LIFETIME_MSECS.toString).toInt,
    getParameter(CONNECTION_ACQUISITION_TIMEOUT_MSECS, DEFAULT_TIMEOUT.toString).toInt,
    getParameter(
      CONNECTION_LIVENESS_CHECK_TIMEOUT_MSECS,
      DEFAULT_CONNECTION_LIVENESS_CHECK_TIMEOUT_MSECS.toString
    ).toInt,
    getParameter(CONNECTION_TIMEOUT_MSECS, DEFAULT_TIMEOUT.toString).toInt
  )

  val session: Neo4jSessionOptions = Neo4jSessionOptions(
    getParameter(DATABASE, DEFAULT_EMPTY),
    AccessMode.valueOf(getParameter(ACCESS_MODE, DEFAULT_ACCESS_MODE.toString).toUpperCase())
  )

  val nodeMetadata: Neo4jNodeMetadata = initNeo4jNodeMetadata()

  val legacyTypeConversionEnabled: Boolean = getParameter(
    TYPE_CONVERSION,
    DEFAULT_TYPE_CONVERSION
  ).toLowerCase(Locale.ROOT).equals("legacy")

  // Parses "k1:v1,k2:v2" (entries may be backtick-quoted) into a map;
  // an entry without a value maps the key to itself.
  private def mapPropsString(strOpt: Option[String]): Option[Map[String, String]] = strOpt.map(str =>
    str.split(",")
      .map(_.trim)
      .filter(_.nonEmpty)
      .map(s => {
        val keys = if (s.startsWith("`")) {
          val pattern = "`[^`]+`".r
          val groups = pattern findAllIn s
          groups
            .map(_.replaceAll("`", ""))
            .toArray
        } else {
          s.split(":")
        }
        if (keys.length == 2) {
          (keys(0), keys(1))
        } else {
          (keys(0), keys(0))
        }
      })
      .toMap
  )

  // Builds node metadata from the node-keys/labels/props option strings;
  // defaults read the node-specific options, but relationship parsing passes its own.
  private def initNeo4jNodeMetadata(
    nodeKeysString: String = getParameter(NODE_KEYS, ""),
    labelsString: String = query.value,
    nodePropsString: String = "",
    skipNullKeys: Boolean = getParameter(NODE_KEYS_SKIP_NULLS, "false").toBoolean
  ): Neo4jNodeMetadata = {
    val nodeKeys = mapPropsString(Some(nodeKeysString)).getOrElse(Map.empty[String, String])
val nodeProps = mapPropsString(Some(nodePropsString)).getOrElse(Map.empty[String, String])
    // Labels arrive colon-separated (e.g. "Person:Admin"); blanks are dropped.
    val labels = labelsString
      .split(":")
      .map(_.trim)
      .filter(_.nonEmpty)
    Neo4jNodeMetadata(labels, nodeKeys, nodeProps, skipNullKeys)
  }

  val transactionSettings: Neo4jTransactionSettings = initNeo4jTransactionSettings()

  // Semicolon-separated statements to run before the job starts.
  val script: Array[String] = getParameter(SCRIPT)
    .split(";")
    .map(_.trim)
    .filterNot(_.isEmpty)

  private def initNeo4jTransactionSettings(): Neo4jTransactionSettings = {
    val retries = getParameter(TRANSACTION_RETRIES, DEFAULT_TRANSACTION_RETRIES.toString).toInt
    // Neo4j error codes that should abort the job instead of being retried.
    val failOnTransactionCodes = getParameter(TRANSACTION_CODES_FAIL, DEFAULT_EMPTY)
      .split(",")
      .map(_.trim)
      .filter(_.nonEmpty)
      .toSet
    val batchSize = getParameter(BATCH_SIZE, DEFAULT_BATCH_SIZE.toString).toInt
    val retryTimeout = getParameter(TRANSACTION_RETRY_TIMEOUT, DEFAULT_TRANSACTION_RETRY_TIMEOUT.toString).toInt
    Neo4jTransactionSettings(retries, failOnTransactionCodes, batchSize, retryTimeout)
  }

  val relationshipMetadata: Neo4jRelationshipMetadata = initNeo4jRelationshipMetadata()

  // Assembles relationship metadata: source/target node metadata from their own
  // option namespaces plus relationship-level save strategy and keys.
  private def initNeo4jRelationshipMetadata(): Neo4jRelationshipMetadata = {
    val source = initNeo4jNodeMetadata(
      getParameter(RELATIONSHIP_SOURCE_NODE_KEYS, ""),
      getParameter(RELATIONSHIP_SOURCE_LABELS, ""),
      getParameter(RELATIONSHIP_SOURCE_NODE_PROPS, ""),
      getParameter(RELATIONSHIP_SOURCE_NODE_KEYS_SKIP_NULLS, "false").toBoolean
    )
    val target = initNeo4jNodeMetadata(
      getParameter(RELATIONSHIP_TARGET_NODE_KEYS, ""),
      getParameter(RELATIONSHIP_TARGET_LABELS, ""),
      getParameter(RELATIONSHIP_TARGET_NODE_PROPS, ""),
      getParameter(RELATIONSHIP_TARGET_NODE_KEYS_SKIP_NULLS, "false").toBoolean
    )
    val nodeMap = getParameter(RELATIONSHIP_NODES_MAP, DEFAULT_RELATIONSHIP_NODES_MAP.toString).toBoolean
    val relProps = mapPropsString(getParameterOption(RELATIONSHIP_PROPERTIES))
    val writeStrategy = RelationshipSaveStrategy.withCaseInsensitiveName(getParameter(
      RELATIONSHIP_SAVE_STRATEGY,
      DEFAULT_RELATIONSHIP_SAVE_STRATEGY.toString
    ).toUpperCase)
    val
sourceSaveMode = NodeSaveMode.withCaseInsensitiveName(getParameter(
      RELATIONSHIP_SOURCE_SAVE_MODE,
      DEFAULT_RELATIONSHIP_SOURCE_SAVE_MODE.toString
    ))
    val targetSaveMode = NodeSaveMode.withCaseInsensitiveName(getParameter(
      RELATIONSHIP_TARGET_SAVE_MODE,
      DEFAULT_RELATIONSHIP_TARGET_SAVE_MODE.toString
    ))
    val relationshipKeys = mapPropsString(getParameterOption(RELATIONSHIP_KEYS)).getOrElse(Map.empty)
    Neo4jRelationshipMetadata(
      source,
      target,
      sourceSaveMode,
      targetSaveMode,
      relProps,
      query.value,
      nodeMap,
      writeStrategy,
      relationshipKeys,
      getParameter(RELATIONSHIP_KEYS_SKIP_NULLS, "false").toBoolean
    )
  }

  private def initNeo4jQueryMetadata(): Neo4jQueryMetadata = Neo4jQueryMetadata(
    query.value.trim,
    getParameter(QUERY_COUNT, "").trim
  )

  val queryMetadata: Neo4jQueryMetadata = initNeo4jQueryMetadata()

  // Everything under "gds." (prefix stripped) becomes a nested parameter map for the GDS call.
  private def initNeo4jGdsMetadata(): Neo4jGdsMetadata = Neo4jGdsMetadata(
    options.asScala
      .filterKeys(k => k.startsWith("gds."))
      .map(t => (t._1.substring("gds.".length), t._2))
      .toMap
      .toNestedJavaMap
  )

  val gdsMetadata: Neo4jGdsMetadata = initNeo4jGdsMetadata()

  val partitions: Int = getParameter(PARTITIONS, DEFAULT_PARTITIONS.toString).toInt

  // Falls back to the streaming property name when no explicit ORDER BY is configured.
  val streamingOrderBy: String = getParameter(ORDER_BY, getParameter(STREAMING_PROPERTY_NAME))

  // "apoc.*" option values are JSON documents parsed into procedure-configuration maps.
  val apocConfig: Neo4jApocConfig = Neo4jApocConfig(options.asScala
    .filterKeys(_.startsWith("apoc."))
    .mapValues(Neo4jUtil.mapper.readValue(_, classOf[java.util.Map[String, AnyRef]]).asScala)
    .toMap)

  // Table name derived from the configuration; random (per call) for plain queries.
  def getTableName: String = query.queryType match {
    case QueryType.LABELS => s"table_${nodeMetadata.labels.mkString("-")}"
    case QueryType.RELATIONSHIP =>
      s"table_${relationshipMetadata.source.labels.mkString("-")}" +
        s"_${relationshipMetadata.relationshipType}" +
        s"_${relationshipMetadata.target.labels.mkString("-")}"
    case _ => s"table_query_${UUID.randomUUID()}"
  }

  val streamingOptions: Neo4jStreamingOptions = Neo4jStreamingOptions(
    getParameter(STREAMING_PROPERTY_NAME),
    StreamingFrom.withCaseInsensitiveName(getParameter(STREAMING_FROM,
DEFAULT_STREAMING_FROM.toString)),
  getParameter(STREAMING_QUERY_OFFSET)
)

/**
 * Builds the driver-level [[TransactionConfig]] from the options.
 * Applies a transaction timeout only when one was configured
 * (the default, DEFAULT_TRANSACTION_TIMEOUT, is null = driver default).
 */
def toNeo4jTransactionConfig: TransactionConfig = {
  val timeout = getParameter(TRANSACTION_TIMEOUT_MSECS, DEFAULT_TRANSACTION_TIMEOUT)
  val builder = TransactionConfig.builder()
  if (timeout != null) {
    // Duration.ofMillis takes a long; parse with toLong (was toInt) so
    // timeouts larger than Int.MaxValue milliseconds don't fail to parse.
    val duration = Duration.ofMillis(timeout.toLong)
    builder.withTimeout(duration)
  }
  builder.build()
}
}

/** Options driving a structured-streaming read: tracked property, start point, offset query. */
case class Neo4jStreamingOptions(
  propertyName: String,
  from: StreamingFrom.Value,
  queryOffset: String
)

/** Per-procedure APOC configuration, keyed by option suffix. */
case class Neo4jApocConfig(procedureConfigMap: Map[String, AnyRef])

/** Fine-grained (new style) schema optimization switches. */
case class Neo4jSchemaOptimizations(
  nodeConstraint: ConstraintsOptimizationType.Value,
  relConstraint: ConstraintsOptimizationType.Value,
  schemaConstraints: Set[SchemaConstraintsOptimizationType.Value]
)

/** Aggregated schema handling options (strategy, flattening, optimizations). */
case class Neo4jSchemaMetadata(
  flattenLimit: Int,
  strategy: SchemaStrategy.Value,
  optimizationType: OptimizationType.Value,
  optimization: Neo4jSchemaOptimizations,
  mapGroupDuplicateKeys: Boolean
)

/** Transaction retry/batch settings; `shouldFailOn` consults the fail-fast code list. */
case class Neo4jTransactionSettings(
  retries: Int,
  failOnTransactionCodes: Set[String],
  batchSize: Int,
  retryTimeout: Long
) {

  /** True when the exception is a Neo4jException whose code is configured as non-retryable. */
  def shouldFailOn(exception: Throwable): Boolean = {
    exception match {
      case e: Neo4jException => failOnTransactionCodes.contains(e.code())
      case _                 => false
    }
  }
}

/** Labels plus key/property field mappings for one node side. */
case class Neo4jNodeMetadata(
  labels: Seq[String],
  nodeKeys: Map[String, String],
  properties: Map[String, String],
  skipNullKeys: Boolean = false
) {
  def includesProperty(name: String): Boolean = nodeKeys.contains(name) || properties.contains(name)
}

/** Full relationship mapping: endpoints, save modes, properties, strategy, keys. */
case class Neo4jRelationshipMetadata(
  source: Neo4jNodeMetadata,
  target: Neo4jNodeMetadata,
  sourceSaveMode: NodeSaveMode.Value,
  targetSaveMode: NodeSaveMode.Value,
  properties: Option[Map[String, String]],
  relationshipType: String,
  nodeMap: Boolean,
  saveStrategy: RelationshipSaveStrategy.Value,
  relationshipKeys: Map[String, String],
  skipNullKeys: Boolean = false
)

/** User query plus the optional count query used to size partitions. */
case class Neo4jQueryMetadata(query: String, queryCount: String)

/** Nested parameter map handed to GDS procedure calls. */
case class Neo4jGdsMetadata(parameters: util.Map[String, Any])

case class
Neo4jQueryOptions(queryType: QueryType.Value, value: String)

/** Session-scoped settings: target database and driver access mode (defaults to READ). */
case class Neo4jSessionOptions(database: String, accessMode: AccessMode = AccessMode.READ) {

  // Builds the driver SessionConfig; the database is only set when non-null/non-empty
  // so the server-side default database applies otherwise.
  def toNeo4jSession(): SessionConfig = {
    val builder = SessionConfig.builder()
      .withDefaultAccessMode(accessMode)
    if (database != null && database != "") {
      builder.withDatabase(database)
    }
    builder.build()
  }
}

/**
 * Connection-level settings for the Bolt driver. Serializable because it is
 * shipped to Spark executors. Timeouts of -1 mean "use the driver default".
 */
case class Neo4jDriverOptions(
  url: String,
  auth: String,
  authParameters: Map[String, String],
  encryption: Boolean,
  trustStrategy: Option[String],
  certificatePath: String,
  lifetime: Int,
  acquisitionTimeout: Int,
  livenessCheckTimeout: Int,
  connectionTimeout: Int
) extends Serializable {

  // Creates a driver against the first configured URL; auth tokens are supplied
  // lazily so they can be refreshed (re-auth) by ReAuthDriverFactory.
  def createDriver(): Driver = {
    val (url, _) = connectionUrls
    ReAuthDriverFactory.driver(url, createAuthTokenSupplier, toDriverConfig)
  }

  private def toDriverConfig: Config = {
    val builder = Config.builder()
      .withUserAgent(s"neo4j-${Neo4jUtil.connectorEnv}-connector/${Neo4jUtil.connectorVersion}")
      .withLogging(Logging.slf4j())
    // -1 (the default) leaves the driver's own default in place for each knob.
    if (lifetime > -1) builder.withMaxConnectionLifetime(lifetime, TimeUnit.MILLISECONDS)
    if (acquisitionTimeout > -1) builder.withConnectionAcquisitionTimeout(acquisitionTimeout, TimeUnit.MILLISECONDS)
    if (livenessCheckTimeout > -1) builder.withConnectionLivenessCheckTimeout(livenessCheckTimeout, TimeUnit.MILLISECONDS)
    if (connectionTimeout > -1) builder.withConnectionTimeout(connectionTimeout, TimeUnit.MILLISECONDS)
    val (primaryUrl, resolvers) = connectionUrls
    primaryUrl.getScheme match {
      // "+s"/"+ssc" schemes already encode encryption + trust — never override them here.
      case "neo4j+s" | "neo4j+ssc" | "bolt+s" | "bolt+ssc" => ()
      case _ => {
        if (!encryption) {
          builder.withoutEncryption()
        } else {
          builder.withEncryption()
        }
        // valueOf throws IllegalArgumentException for unknown strategy names.
        trustStrategy
          .map(Config.TrustStrategy.Strategy.valueOf)
          .map {
            case TrustStrategy.Strategy.TRUST_ALL_CERTIFICATES => TrustStrategy.trustAllCertificates()
            case TrustStrategy.Strategy.TRUST_SYSTEM_CA_SIGNED_CERTIFICATES =>
              TrustStrategy.trustSystemCertificates()
            case TrustStrategy.Strategy.TRUST_CUSTOM_CA_SIGNED_CERTIFICATES =>
              TrustStrategy.trustCustomCertificateSignedBy(new File(certificatePath))
          }.foreach(builder.withTrustStrategy)
      }
    }
    // Extra comma-separated URLs become a static server-address resolver.
    if (resolvers.nonEmpty) {
      builder.withResolver(_ => resolvers.asJava)
    }
    builder.build()
  }

  // Splits the comma-separated url option into the primary URI plus the
  // remaining addresses (default Bolt port 7687 when none is given).
  // public only for testing purposes
  @TestOnly
  def connectionUrls: (URI, Set[ServerAddress]) = {
    val urls = url.split(",").toList
    val resolved = urls
      .drop(1)
      .map(_.trim)
      .map(URI.create)
      .map(uri => ServerAddress.of(uri.getHost, if (uri.getPort > -1) uri.getPort else 7687))
      .toSet
    (URI.create(urls.head.trim), resolved)
  }

  // Looks up the AuthenticationTokenSupplierFactory matching the configured auth
  // type via ServiceLoader; fails when none or more than one implementation matches.
  private def createAuthTokenSupplier: Supplier[AuthenticationToken] = {
    if (auth == null || auth.isEmpty) {
      throw new IllegalArgumentException(s"Authentication type name is required")
    }
    val supplierFactories = ServiceLoader.load(
      classOf[AuthenticationTokenSupplierFactory],
      getClass.getClassLoader
    ).iterator()
      .asScala
      .toList
    val filteredSupplierFactories = supplierFactories.filter(s => s.getName != null && s.getName.equalsIgnoreCase(auth))
    if (filteredSupplierFactories.isEmpty) {
      throw new IllegalArgumentException(
        s"Authentication method '$auth' is not supported. Supported authentication methods are: ${supplierFactories.map(_.getName).mkString(", ")}"
      )
    }
    if (filteredSupplierFactories.size > 1) {
      throw new IllegalArgumentException(
        s"Multiple implementation for authentication type '$auth' are found"
      )
    }
    val username = authParameters.get("username")
    val password = authParameters.get("password")
    filteredSupplierFactories.head.create(username.orNull, password.orNull, authParameters.asJava)
  }
}

/** Option-key constants and defaults for the connector. */
object Neo4jOptions {

  // connection options
  val URL = "url"

  // auth
  val AUTH = "authentication"
  val AUTH_TYPE = "authentication.type"

  // driver
  val ENCRYPTION_ENABLED = "encryption.enabled"
  val ENCRYPTION_TRUST_STRATEGY = "encryption.trust.strategy"
  val ENCRYPTION_CA_CERTIFICATE_PATH = "encryption.ca.certificate.path"
  val CONNECTION_MAX_LIFETIME_MSECS = "connection.max.lifetime.msecs"
  val CONNECTION_LIVENESS_CHECK_TIMEOUT_MSECS = "connection.liveness.timeout.msecs"
  val CONNECTION_ACQUISITION_TIMEOUT_MSECS = "connection.acquisition.timeout.msecs"
  val CONNECTION_TIMEOUT_MSECS = "connection.timeout.msecs"
  val TRANSACTION_TIMEOUT_MSECS = "db.transaction.timeout"

  // session options
  val DATABASE = "database"
  val ACCESS_MODE = "access.mode"
  val SAVE_MODE = "save.mode"

  val PUSHDOWN_FILTERS_ENABLED = "pushdown.filters.enabled"
  val PUSHDOWN_COLUMNS_ENABLED = "pushdown.columns.enabled"
  val PUSHDOWN_AGGREGATE_ENABLED = "pushdown.aggregate.enabled"
  val PUSHDOWN_LIMIT_ENABLED = "pushdown.limit.enabled"
  val PUSHDOWN_TOPN_ENABLED = "pushdown.topN.enabled"

  // schema options
  val SCHEMA_STRATEGY = "schema.strategy"
  val SCHEMA_FLATTEN_LIMIT = "schema.flatten.limit"
  // deprecated in favor of...
val SCHEMA_OPTIMIZATION_TYPE = "schema.optimization.type"
  // ...these options
  val SCHEMA_OPTIMIZATION = "schema.optimization"
  val SCHEMA_OPTIMIZATION_NODE_KEY = "schema.optimization.node.keys"
  val SCHEMA_OPTIMIZATION_RELATIONSHIP_KEY = "schema.optimization.relationship.keys"

  // map aggregation
  val SCHEMA_MAP_GROUP_DUPLICATE_KEYS = "schema.map.group.duplicate.keys"

  // index options
  val INDEX_AWAIT_TIMEOUT_SEC = "index.await.timeout"

  // partitions
  val PARTITIONS = "partitions"

  // orderBy
  val ORDER_BY = "orderBy"

  // skip.nulls postfix
  val SKIP_NULLS = "skip.nulls"

  // Node Metadata
  // NOTE: declaration order matters — the RELATIONSHIP_* keys below are built
  // from NODE_KEYS/NODE_PROPS/SAVE_MODE and would see null if moved above them.
  val NODE_KEYS = "node.keys"
  val NODE_KEYS_SKIP_NULLS = s"${NODE_KEYS}.${SKIP_NULLS}"
  val NODE_PROPS = "node.properties"
  val BATCH_SIZE = "batch.size"

  val SUPPORTED_SAVE_MODES = Seq(SaveMode.Overwrite, SaveMode.ErrorIfExists, SaveMode.Append)

  // Relationship Metadata — keys are derived from the lowercase QueryType names,
  // e.g. "relationship.source.labels", "relationship.target.node.keys", ...
  val RELATIONSHIP_SOURCE_LABELS =
    s"${QueryType.RELATIONSHIP.toString.toLowerCase}.source.${QueryType.LABELS.toString.toLowerCase}"
  val RELATIONSHIP_SOURCE_NODE_KEYS = s"${QueryType.RELATIONSHIP.toString.toLowerCase}.source.$NODE_KEYS"
  val RELATIONSHIP_SOURCE_NODE_KEYS_SKIP_NULLS = s"${RELATIONSHIP_SOURCE_NODE_KEYS}.${SKIP_NULLS}"
  val RELATIONSHIP_SOURCE_NODE_PROPS = s"${QueryType.RELATIONSHIP.toString.toLowerCase}.source.$NODE_PROPS"
  val RELATIONSHIP_SOURCE_SAVE_MODE = s"${QueryType.RELATIONSHIP.toString.toLowerCase}.source.$SAVE_MODE"
  val RELATIONSHIP_TARGET_LABELS =
    s"${QueryType.RELATIONSHIP.toString.toLowerCase}.target.${QueryType.LABELS.toString.toLowerCase}"
  val RELATIONSHIP_TARGET_NODE_KEYS = s"${QueryType.RELATIONSHIP.toString.toLowerCase}.target.$NODE_KEYS"
  val RELATIONSHIP_TARGET_NODE_KEYS_SKIP_NULLS = s"${RELATIONSHIP_TARGET_NODE_KEYS}.${SKIP_NULLS}"
  val RELATIONSHIP_TARGET_NODE_PROPS = s"${QueryType.RELATIONSHIP.toString.toLowerCase}.target.$NODE_PROPS"
  val RELATIONSHIP_TARGET_SAVE_MODE = s"${QueryType.RELATIONSHIP.toString.toLowerCase}.target.$SAVE_MODE"
  val RELATIONSHIP_PROPERTIES =
    s"${QueryType.RELATIONSHIP.toString.toLowerCase}.properties"
  val RELATIONSHIP_NODES_MAP = s"${QueryType.RELATIONSHIP.toString.toLowerCase}.nodes.map"
  val RELATIONSHIP_SAVE_STRATEGY = s"${QueryType.RELATIONSHIP.toString.toLowerCase}.save.strategy"
  val RELATIONSHIP_KEYS = s"${QueryType.RELATIONSHIP.toString.toLowerCase}.keys"
  val RELATIONSHIP_KEYS_SKIP_NULLS = s"${RELATIONSHIP_KEYS}.${SKIP_NULLS}"

  // Query metadata
  val QUERY_COUNT = "query.count"

  // Transaction Metadata
  val TRANSACTION_RETRIES = "transaction.retries"
  val TRANSACTION_RETRY_TIMEOUT = "transaction.retry.timeout"
  val TRANSACTION_CODES_FAIL = "transaction.codes.fail"

  // Streaming
  val STREAMING_PROPERTY_NAME = "streaming.property.name"
  val STREAMING_FROM = "streaming.from"
  val STREAMING_QUERY_OFFSET = "streaming.query.offset"

  val SCRIPT = "script"

  // Data conversion
  val TYPE_CONVERSION = "type.conversion"

  // defaults
  val DEFAULT_EMPTY = ""
  val DEFAULT_TIMEOUT: Int = -1
  val DEFAULT_ACCESS_MODE = AccessMode.READ
  val DEFAULT_AUTH_TYPE = "basic"
  val DEFAULT_ENCRYPTION_ENABLED = false
  val DEFAULT_SCHEMA_FLATTEN_LIMIT = 10
  val DEFAULT_BATCH_SIZE = 5000
  val DEFAULT_TRANSACTION_RETRIES = 3
  val DEFAULT_TRANSACTION_RETRY_TIMEOUT = 0
  // null means "no explicit transaction timeout" — see toNeo4jTransactionConfig.
  val DEFAULT_TRANSACTION_TIMEOUT = null
  val DEFAULT_RELATIONSHIP_NODES_MAP = false
  val DEFAULT_SCHEMA_STRATEGY = SchemaStrategy.SAMPLE
  val DEFAULT_SCHEMA_OPTIMIZATION_NODE_KEY = ConstraintsOptimizationType.NONE
  val DEFAULT_SCHEMA_OPTIMIZATION_RELATIONSHIP_KEY = ConstraintsOptimizationType.NONE
  val DEFAULT_SCHEMA_OPTIMIZATION = SchemaConstraintsOptimizationType.NONE
  val DEFAULT_RELATIONSHIP_SAVE_STRATEGY: RelationshipSaveStrategy.Value = RelationshipSaveStrategy.NATIVE
  val DEFAULT_RELATIONSHIP_SOURCE_SAVE_MODE: NodeSaveMode.Value = NodeSaveMode.Match
  val DEFAULT_RELATIONSHIP_TARGET_SAVE_MODE: NodeSaveMode.Value = NodeSaveMode.Match
  val DEFAULT_PUSHDOWN_FILTERS_ENABLED = true
  val DEFAULT_PUSHDOWN_COLUMNS_ENABLED = true
  val DEFAULT_PUSHDOWN_AGGREGATE_ENABLED = true
  val
DEFAULT_PUSHDOWN_LIMIT_ENABLED = true
  val DEFAULT_PUSHDOWN_TOPN_ENABLED = true
  val DEFAULT_PARTITIONS = 1
  val DEFAULT_OPTIMIZATION_TYPE = OptimizationType.NONE
  val DEFAULT_SAVE_MODE = SaveMode.Overwrite
  val DEFAULT_STREAMING_FROM = StreamingFrom.NOW

  // Default values optimizations for Aura please look at: https://aura.support.neo4j.com/hc/en-us/articles/1500002493281-Neo4j-Java-driver-settings-for-Aura
  val DEFAULT_CONNECTION_MAX_LIFETIME_MSECS = Duration.ofMinutes(8).toMillis
  val DEFAULT_CONNECTION_LIVENESS_CHECK_TIMEOUT_MSECS = Duration.ofMinutes(2).toMillis
  val DEFAULT_MAP_GROUP_DUPLICATE_KEYS = false
  val DEFAULT_INDEX_AWAIT_TIMEOUT_SEC = 300

  // NOTE(review): declared as `var` — nothing in this file reassigns it, so it
  // looks like it should be a `val`; confirm no test/consumer mutates it.
  var DEFAULT_AUTH_PARAMETERS: Map[String, String] =
    Seq("username", "password", "ticket", "principal", "credentials", "realm", "scheme", "token")
      .map(name => name -> DEFAULT_EMPTY).toMap

  private val DEFAULT_TYPE_CONVERSION = "default"

  // Merges Spark-session-level "neo4j.*" settings (prefix stripped) with the
  // per-DataFrame options; the explicit options win on key collisions.
  def fromSession(sparkSession: Option[SparkSession], options: java.util.Map[String, String]): Neo4jOptions = {
    val sessionLevelOptions = sparkSession
      .map {
        _.conf
          .getAll
          .filterKeys(k => k.startsWith("neo4j."))
          .map { elem => (elem._1.substring("neo4j.".length), elem._2) }
          .toMap
      }
      .getOrElse(Map.empty)
    new Neo4jOptions((sessionLevelOptions ++ options.asScala).asJava)
  }
}

/** Enumeration whose values can be looked up ignoring case; throws NoSuchElementException otherwise. */
class CaseInsensitiveEnumeration extends Enumeration {

  def withCaseInsensitiveName(s: String): Value = {
    values.find(_.toString.toLowerCase() == s.toLowerCase).getOrElse(
      throw new NoSuchElementException(s"No value found for '$s'")
    )
  }
}

object StreamingFrom extends CaseInsensitiveEnumeration {
  val ALL, NOW = Value

  // Maps the enum to a starting timestamp: ALL reads everything (-1), NOW starts from "now".
  // NOTE(review): the method `value()` shares its name with the constructor
  // parameter `value` it reads — compiles, but fragile; consider renaming.
  class StreamingFromValue(value: Value) {

    def value(): Long = value match {
      case ALL => -1L
      case NOW => System.currentTimeMillis()
    }
  }

  implicit def valToStreamingFromValue(value: Value): StreamingFromValue = new StreamingFromValue(value)
}

object StorageType extends CaseInsensitiveEnumeration {
  val NEO4J, SPARK = Value
}

object QueryType extends CaseInsensitiveEnumeration {
  val QUERY, LABELS, RELATIONSHIP, GDS = Value
}

object RelationshipSaveStrategy extends CaseInsensitiveEnumeration {
  val NATIVE, KEYS = Value
}

object NodeSaveMode extends CaseInsensitiveEnumeration {
  val Overwrite, ErrorIfExists, Match, Append = Value

  // Translates Spark's SaveMode into the node-level equivalent; SaveMode.Ignore etc. are rejected.
  def fromSaveMode(saveMode: SaveMode): Value = {
    saveMode match {
      case SaveMode.Overwrite     => Overwrite
      case SaveMode.ErrorIfExists => ErrorIfExists
      case SaveMode.Append        => Append
      case _                      => throw new IllegalArgumentException(s"SaveMode $saveMode not supported")
    }
  }
}

object SchemaStrategy extends CaseInsensitiveEnumeration {
  val STRING, SAMPLE = Value
}

object OptimizationType extends CaseInsensitiveEnumeration {
  val INDEX, NODE_CONSTRAINTS, NONE = Value
}

object ConstraintsOptimizationType extends CaseInsensitiveEnumeration {
  val KEY, UNIQUE, NONE = Value
}

object SchemaConstraintsOptimizationType extends CaseInsensitiveEnumeration {
  val TYPE, EXISTS, NONE = Value
}

================================================ FILE: common/src/main/scala/org/neo4j/spark/util/Neo4jUtil.scala ================================================
/*
 * Copyright (c) "Neo4j"
 * Neo4j Sweden AB [https://neo4j.com]
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
*/
package org.neo4j.spark.util

import com.fasterxml.jackson.core.JsonGenerator
import com.fasterxml.jackson.databind.JsonSerializer
import com.fasterxml.jackson.databind.ObjectMapper
import com.fasterxml.jackson.databind.SerializerProvider
import com.fasterxml.jackson.databind.module.SimpleModule
import org.apache.spark.sql.sources._
import org.neo4j.caniuse.Neo4j
import org.neo4j.cypherdsl.core._
import org.neo4j.driver.Session
import org.neo4j.driver.Transaction
import org.neo4j.driver.internal.retry.ExponentialBackoffRetryLogic
import org.neo4j.driver.types.Entity
import org.neo4j.driver.types.Path
import org.neo4j.spark.service.SchemaService
import org.neo4j.spark.util.Neo4jImplicits.EntityImplicits
import org.neo4j.spark.util.Neo4jImplicits._
import org.slf4j.Logger

import java.time.temporal.Temporal
import java.util.Properties

import scala.annotation.tailrec

/** Shared helpers: internal field names, JSON mapper, filter-to-Cypher mapping, retry logic. */
object Neo4jUtil {

  val NODE_ALIAS = "n"
  private val INTERNAL_ID_FIELD_NAME = "id"
  val INTERNAL_ID_FIELD = s"<${INTERNAL_ID_FIELD_NAME}>"
  private val INTERNAL_LABELS_FIELD_NAME = "labels"
  val INTERNAL_LABELS_FIELD = s"<${INTERNAL_LABELS_FIELD_NAME}>"
  // NOTE(review): these two literals look garbled (empty where the sibling
  // fields above use "<...>" markers, e.g. "<rel.id>" / "<rel.type>") —
  // likely stripped by an extraction step; verify against the upstream source.
  val INTERNAL_REL_ID_FIELD = s""
  val INTERNAL_REL_TYPE_FIELD = ""
  val RELATIONSHIP_SOURCE_ALIAS = "source"
  val RELATIONSHIP_TARGET_ALIAS = "target"
  val INTERNAL_REL_SOURCE_ID_FIELD = s"<${RELATIONSHIP_SOURCE_ALIAS}.${INTERNAL_ID_FIELD_NAME}>"
  val INTERNAL_REL_TARGET_ID_FIELD = s"<${RELATIONSHIP_TARGET_ALIAS}.${INTERNAL_ID_FIELD_NAME}>"
  val INTERNAL_REL_SOURCE_LABELS_FIELD = s"<${RELATIONSHIP_SOURCE_ALIAS}.${INTERNAL_LABELS_FIELD_NAME}>"
  val INTERNAL_REL_TARGET_LABELS_FIELD = s"<${RELATIONSHIP_TARGET_ALIAS}.${INTERNAL_LABELS_FIELD_NAME}>"
  val RELATIONSHIP_ALIAS = "rel"

  // Connector version/metadata loaded from the bundled properties resource.
  private val properties = new Properties()
  properties.load(Thread.currentThread().getContextClassLoader.getResourceAsStream("neo4j-spark-connector.properties"))

  // Closes any AutoCloseable, never propagating failures; sessions/transactions
  // are closed only when still open. Errors are logged when a logger is given.
  def closeSafely(autoCloseable: AutoCloseable, logger: Logger = null): Unit = {
    try {
      autoCloseable match {
        case s: Session     => if (s.isOpen) s.close()
        case t: Transaction => if (t.isOpen) t.close()
        case null           => ()
        case _              => autoCloseable.close()
      }
    } catch {
      case t: Throwable => if (logger != null) logger
          .warn(s"Cannot close ${autoCloseable.getClass.getSimpleName} because of the following exception:", t)
    }
  }

  // Shared Jackson mapper with custom serializers so driver Paths, Entities
  // and temporal values survive JSON serialization (used e.g. for APOC config).
  val mapper = new ObjectMapper()
  private val module = new SimpleModule("Neo4jApocSerializer")
  // Paths are rendered via their string representation.
  module.addSerializer(
    classOf[Path],
    new JsonSerializer[Path]() {

      override def serialize(path: Path, jsonGenerator: JsonGenerator, serializerProvider: SerializerProvider): Unit =
        jsonGenerator.writeString(path.toString)
    }
  )
  // Nodes/relationships are rendered as their property maps.
  module.addSerializer(
    classOf[Entity],
    new JsonSerializer[Entity]() {

      override def serialize(
        entity: Entity,
        jsonGenerator: JsonGenerator,
        serializerProvider: SerializerProvider
      ): Unit = jsonGenerator.writeObject(entity.toMap)
    }
  )
  // Temporal values are written raw (unquoted) in ISO form.
  module.addSerializer(
    classOf[Temporal],
    new JsonSerializer[Temporal]() {

      override def serialize(
        entity: Temporal,
        jsonGenerator: JsonGenerator,
        serializerProvider: SerializerProvider
      ): Unit = jsonGenerator.writeRaw(entity.toString)
    }
  )
  mapper.registerModule(module)

  // True when the (trimmed) string parses as a Long; null-safe.
  def isLong(str: String): Boolean = {
    if (str == null) {
      false
    } else {
      try {
        str.trim.toLong
        true
      } catch {
        case _: NumberFormatException => false
      }
    }
  }

  def connectorVersion: String = properties.getOrDefault("version", "UNKNOWN").toString

  // Platform tag for the driver user agent; overridable via system property,
  // auto-detects Databricks via its runtime env var, else "spark".
  def connectorEnv: String = Option(System.getProperty("neo4j.spark.platform"))
    .getOrElse(defaultConnectorEnv)

  private def defaultConnectorEnv: String = Option(System.getenv("DATABRICKS_RUNTIME_VERSION"))
    .map(_ => "databricks")
    .getOrElse("spark")

  // Dotted attribute names address nested map properties in Cypher-DSL.
  def getCorrectProperty(container: PropertyContainer, attribute: String): Property = {
    container.property(attribute.split('.'): _*)
  }

  // Flattens pushed-down Spark filters into a Cypher parameter map:
  // unquoted attribute name -> converted value.
  def paramsFromFilters(filters: Array[Filter]): Map[String, Any] = {
    filters.flatMap(f => f.flattenFilters).map(_.getAttributeAndValue)
      .filter(_.nonEmpty)
      .map(valAndAtt => valAndAtt.head.toString.unquote() -> toParamValue(valAndAtt(1)))
      .toMap
  }

  def toParamValue(value: Any):
Any = {
  // Normalizes JDBC temporal types into driver-friendly values;
  // everything else is passed through untouched.
  value match {
    case date: java.sql.Date           => date.toString
    case timestamp: java.sql.Timestamp => timestamp.toLocalDateTime
    case _                             => value
  }
}

// Wraps a filter value into a Cypher parameter expression; SQL dates/timestamps
// are wrapped in date()/localdatetime() so Neo4j compares temporal types correctly.
def valueToCypherExpression(attribute: String, value: Any): Expression = {
  val parameter = Cypher.parameter(attribute.toParameterName(value))
  value match {
    case d: java.sql.Date      => Functions.date(parameter)
    case t: java.sql.Timestamp => Functions.localdatetime(parameter)
    case _                     => parameter
  }
}

/**
 * Translates a pushed-down Spark SQL [[Filter]] into a Cypher-DSL [[Condition]]
 * on the given property container. `attributeAlias`, when set, overrides the
 * property looked up while the parameter name still derives from the filter's
 * own attribute. Unsupported filters raise IllegalArgumentException.
 */
def mapSparkFiltersToCypher(
  filter: Filter,
  container: PropertyContainer,
  attributeAlias: Option[String] = None
): Condition = {
  filter match {
    // EqualNullSafe: (prop IS NULL AND param IS NULL) OR prop = param
    case eqns: EqualNullSafe =>
      val parameter = valueToCypherExpression(eqns.attribute, eqns.value)
      val property = getCorrectProperty(container, attributeAlias.getOrElse(eqns.attribute))
      property.isNull.and(parameter.isNull)
        .or(property.isEqualTo(parameter))
    case eq: EqualTo => getCorrectProperty(container, attributeAlias.getOrElse(eq.attribute))
        .isEqualTo(valueToCypherExpression(eq.attribute, eq.value))
    case gt: GreaterThan => getCorrectProperty(container, attributeAlias.getOrElse(gt.attribute))
        .gt(valueToCypherExpression(gt.attribute, gt.value))
    case gte: GreaterThanOrEqual => getCorrectProperty(container, attributeAlias.getOrElse(gte.attribute))
        .gte(valueToCypherExpression(gte.attribute, gte.value))
    case lt: LessThan => getCorrectProperty(container, attributeAlias.getOrElse(lt.attribute))
        .lt(valueToCypherExpression(lt.attribute, lt.value))
    case lte: LessThanOrEqual => getCorrectProperty(container, attributeAlias.getOrElse(lte.attribute))
        .lte(valueToCypherExpression(lte.attribute, lte.value))
    case in: In => getCorrectProperty(container, attributeAlias.getOrElse(in.attribute))
        .in(valueToCypherExpression(in.attribute, in.values))
    case startWith: StringStartsWith => getCorrectProperty(container, attributeAlias.getOrElse(startWith.attribute))
        .startsWith(valueToCypherExpression(startWith.attribute, startWith.value))
    case endsWith: StringEndsWith => getCorrectProperty(container, attributeAlias.getOrElse(endsWith.attribute))
        .endsWith(valueToCypherExpression(endsWith.attribute, endsWith.value))
    case contains: StringContains => getCorrectProperty(container, attributeAlias.getOrElse(contains.attribute))
        .contains(valueToCypherExpression(contains.attribute, contains.value))
    case notNull: IsNotNull => getCorrectProperty(container, attributeAlias.getOrElse(notNull.attribute)).isNotNull
    case isNull: IsNull     => getCorrectProperty(container, attributeAlias.getOrElse(isNull.attribute)).isNull
    // Not: recursively translate the child and negate it.
    case not: Not                => mapSparkFiltersToCypher(not.child, container, attributeAlias).not()
    case filter @ (_: Filter)    => throw new IllegalArgumentException(s"Filter of type `$filter` is not supported.")
  }
}

// For relationship streams the tracked property lives on the "rel" alias.
def getStreamingPropertyName(options: Neo4jOptions): String = options.query.queryType match {
  case QueryType.RELATIONSHIP => s"rel.${options.streamingOptions.propertyName}"
  case _                      => options.streamingOptions.propertyName
}

// Runs `function` against a fresh SchemaService, always closing the service;
// the driver cache is torn down only on failure so successful callers keep it warm.
def callSchemaService[T](
  neo4j: Neo4j,
  neo4jOptions: Neo4jOptions,
  jobId: String,
  filters: Array[Filter],
  function: SchemaService => T
): T = {
  val driverCache = new DriverCache(neo4jOptions.connection)
  val schemaService = new SchemaService(neo4j, neo4jOptions, driverCache, filters)
  var hasError = false
  try {
    function(schemaService)
  } catch {
    case e: Throwable => {
      hasError = true
      throw e
    }
  } finally {
    schemaService.close()
    if (hasError) {
      driverCache.close()
    }
  }
}

// Walks the cause chain: retryable if the driver's backoff logic deems any
// link retryable. The recursive call sits in `||` tail position, so @tailrec holds.
@tailrec
def isRetryableException(exception: Throwable): Boolean = {
  if (exception == null) {
    false
  } else ExponentialBackoffRetryLogic.isRetryable(exception) || isRetryableException(
    exception.getCause
  )
}
}

================================================ FILE: common/src/main/scala/org/neo4j/spark/util/ValidationUtil.scala ================================================
/*
 * Copyright (c) "Neo4j"
 * Neo4j Sweden AB [https://neo4j.com]
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package org.neo4j.spark.util

/**
 * Precondition helpers: each check throws IllegalArgumentException with the
 * supplied message when the condition is violated, otherwise does nothing.
 * Explicit result types added (public API); logic unchanged.
 */
object ValidationUtil {

  /** Fails when the string is empty ("" — whitespace-only strings pass). */
  def isNotEmpty(str: String, message: String): Unit = if (str.isEmpty) {
    throw new IllegalArgumentException(message)
  }

  /** Fails when the string is empty or whitespace-only. */
  def isNotBlank(str: String, message: String): Unit = if (str.trim.isEmpty) {
    throw new IllegalArgumentException(message)
  }

  /** Fails when the string contains any non-whitespace character. */
  def isBlank(str: String, message: String): Unit = if (str.trim.nonEmpty) {
    throw new IllegalArgumentException(message)
  }

  /** Fails when the sequence is empty. */
  def isNotEmpty(seq: Seq[_], message: String): Unit = if (seq.isEmpty) {
    throw new IllegalArgumentException(message)
  }

  /** Fails when the map is empty. */
  def isNotEmpty(map: Map[_, _], message: String): Unit = if (map.isEmpty) {
    throw new IllegalArgumentException(message)
  }

  /** Fails when the condition is false. */
  def isTrue(boolean: Boolean, message: String): Unit = if (!boolean) {
    throw new IllegalArgumentException(message)
  }

  /** Fails when the condition is true. */
  def isFalse(boolean: Boolean, message: String): Unit = if (boolean) {
    throw new IllegalArgumentException(message)
  }

  /** Unconditional failure with the given message. */
  def isNotValid(message: String): Nothing = throw new IllegalArgumentException(message)
}

================================================ FILE: common/src/main/scala/org/neo4j/spark/util/Validations.scala ================================================
/*
 * Copyright (c) "Neo4j"
 * Neo4j Sweden AB [https://neo4j.com]
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package org.neo4j.spark.util

import org.apache.spark.internal.Logging
import org.apache.spark.sql.SaveMode
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.types.StructType
import org.neo4j.caniuse.Neo4j
import org.neo4j.driver.AccessMode
import org.neo4j.driver.summary
import org.neo4j.spark.service.Neo4jQueryStrategy
import org.neo4j.spark.service.SchemaService
import org.neo4j.spark.util
import org.neo4j.spark.util.Neo4jImplicits.StructTypeImplicit

import java.util.Locale

/** Entry point: runs each distinct validation (duplicates collapse via toSet). */
object Validations {
  def validate(validations: Validation*): Unit = validations.toSet[Validation].foreach(_.validate())
}

/** A single validation; implementations throw IllegalArgumentException on failure. */
trait Validation extends Logging {

  def validate(): Unit

  // Logs (does not fail) when an option is ignored because of another option.
  def ignoreOption(ignoredOption: String, primaryOption: String): Unit =
    logWarning(s"Option `$ignoredOption` is not compatible with `$primaryOption` and will be ignored")
}

/** Verifies that every field referenced by the keys/props options exists in the DataFrame schema. */
case class ValidateSchemaOptions(neo4jOptions: Neo4jOptions, schema: StructType) extends Validation {

  override def validate(): Unit = {
    // option name -> fields it references that are missing from the schema
    val missingFieldsMap = Map(
      Neo4jOptions.NODE_KEYS -> schema.getMissingFields(neo4jOptions.nodeMetadata.nodeKeys.keySet),
      Neo4jOptions.NODE_PROPS -> schema.getMissingFields(neo4jOptions.nodeMetadata.properties.keySet),
      Neo4jOptions.RELATIONSHIP_PROPERTIES -> schema.getMissingFields(
        neo4jOptions.relationshipMetadata.properties.getOrElse(Map.empty).keySet
      ),
      Neo4jOptions.RELATIONSHIP_SOURCE_NODE_PROPS -> schema.getMissingFields(
        neo4jOptions.relationshipMetadata.source.properties.keySet
      ),
      Neo4jOptions.RELATIONSHIP_SOURCE_NODE_KEYS -> schema.getMissingFields(
        neo4jOptions.relationshipMetadata.source.nodeKeys.keySet
      ),
      Neo4jOptions.RELATIONSHIP_TARGET_NODE_PROPS -> schema.getMissingFields(
        neo4jOptions.relationshipMetadata.target.properties.keySet
      ),
      Neo4jOptions.RELATIONSHIP_TARGET_NODE_KEYS -> schema.getMissingFields(
        neo4jOptions.relationshipMetadata.target.nodeKeys.keySet
      )
    )

    val optionsWithMissingFields = missingFieldsMap.filter(_._2.nonEmpty)

    if (optionsWithMissingFields.nonEmpty) {
      throw new IllegalArgumentException(
        s"""Write failed due to the following errors:
           |${optionsWithMissingFields.map(field =>
            s" - Schema is missing ${field._2.mkString(", ")} from option `${field._1}`"
          ).mkString("\n")}
           |
           |The option key and value might be inverted.""".stripMargin
      )
    }
  }
}

/** Checks that the schema-optimization options are coherent with the query type and save mode. */
case class ValidateSchemaMetadataWrite(neo4jOptions: Neo4jOptions, saveMode: SaveMode) extends Validation {

  override def validate(): Unit = {
    val schemaMetadata = neo4jOptions.schemaMetadata
    // Either the deprecated NODE_CONSTRAINTS flag or a new-style node constraint is set.
    val hasNodeOptimizations =
      (schemaMetadata.optimizationType == OptimizationType.NODE_CONSTRAINTS
        || schemaMetadata.optimization.nodeConstraint != ConstraintsOptimizationType.NONE)
    // Deprecated optimizationType combined with any new-style option — not allowed.
    val hasOptimizations = (schemaMetadata.optimizationType != OptimizationType.NONE
      && (schemaMetadata.optimization.nodeConstraint != ConstraintsOptimizationType.NONE
        || schemaMetadata.optimization.relConstraint != ConstraintsOptimizationType.NONE
        || schemaMetadata.optimization.schemaConstraints != Set(SchemaConstraintsOptimizationType.NONE)))
    if (hasOptimizations) {
      // NOTE(review): this interpolated string has no .stripMargin, so the
      // literal `|` characters end up in the error message — confirm intended.
      throw new IllegalArgumentException(
        s"""You cannot combine `${Neo4jOptions.SCHEMA_OPTIMIZATION_TYPE}` with:
           |- `${Neo4jOptions.SCHEMA_OPTIMIZATION_NODE_KEY}`
           |- `${Neo4jOptions.SCHEMA_OPTIMIZATION_RELATIONSHIP_KEY}`
           |- `${Neo4jOptions.SCHEMA_OPTIMIZATION}`
           """
      )
    }
    neo4jOptions.query.queryType match {
      case QueryType.QUERY => {
        // Custom queries cannot carry any schema optimization.
        neo4jOptions.schemaMetadata.optimizationType match {
          case OptimizationType.NONE => // are valid
          case _ => ValidationUtil.isNotValid(
              s"""With Query Type ${neo4jOptions.query.queryType} you can
                 |only use `${Neo4jOptions.SCHEMA_OPTIMIZATION_TYPE}`
                 |`${OptimizationType.NONE}`
                 |""".stripMargin
            )
        }
        if (hasOptimizations) {
          throw new IllegalArgumentException(
            s"With Query Type ${neo4jOptions.query.queryType} you cannot define any optimization"
          )
        }
      }
      case QueryType.LABELS => {
        if (hasNodeOptimizations) {
          ValidationUtil.isTrue(saveMode == SaveMode.Overwrite, "This works only with `mode` `Overwrite`")
          ValidationUtil.isNotEmpty(
            neo4jOptions.nodeMetadata.nodeKeys,
            s"${Neo4jOptions.NODE_KEYS} is required to define the constraints"
          )
        }
      }
      case QueryType.RELATIONSHIP => {
        if (hasNodeOptimizations) {
          // Constraints need explicit keys and Overwrite semantics on both endpoints.
          ValidationUtil.isNotEmpty(
            neo4jOptions.relationshipMetadata.source.nodeKeys,
            s"${Neo4jOptions.RELATIONSHIP_SOURCE_NODE_KEYS} is required to define the constraints"
          )
          ValidationUtil.isNotEmpty(
            neo4jOptions.relationshipMetadata.target.nodeKeys,
            s"${Neo4jOptions.RELATIONSHIP_TARGET_NODE_KEYS} is required to define the constraints"
          )
          ValidationUtil.isTrue(
            neo4jOptions.relationshipMetadata.sourceSaveMode == NodeSaveMode.Overwrite,
            s"This works only with `${Neo4jOptions.RELATIONSHIP_SOURCE_SAVE_MODE}` `${NodeSaveMode.Overwrite}`"
          )
          ValidationUtil.isTrue(
            neo4jOptions.relationshipMetadata.targetSaveMode == NodeSaveMode.Overwrite,
            s"This works only with `${Neo4jOptions.RELATIONSHIP_TARGET_SAVE_MODE}` `${NodeSaveMode.Overwrite}`"
          )
        }
        if (schemaMetadata.optimization.relConstraint != ConstraintsOptimizationType.NONE) {
          ValidationUtil.isTrue(saveMode == SaveMode.Overwrite, s"This works only with `mode` `${SaveMode.Overwrite}`")
          ValidationUtil.isNotEmpty(
            neo4jOptions.relationshipMetadata.relationshipKeys,
            s"${Neo4jOptions.RELATIONSHIP_KEYS} is required to define the constraints"
          )
        }
      }
      case _ => // do nothing
    }
    // Deprecated optimizationType path: requires Overwrite semantics.
    // NOTE(review): this inner match has no default case — a GDS query with a
    // non-NONE optimizationType would hit a MatchError; confirm reachable.
    neo4jOptions.schemaMetadata.optimizationType match {
      case OptimizationType.NONE => // skip it
      case _ => neo4jOptions.query.queryType match {
          case QueryType.LABELS =>
            ValidationUtil.isTrue(saveMode == SaveMode.Overwrite, "This works only with `mode` `SaveMode.Overwrite`")
          case QueryType.RELATIONSHIP => {
            ValidationUtil.isTrue(
              neo4jOptions.relationshipMetadata.sourceSaveMode == NodeSaveMode.Overwrite,
              s"This works only with `${Neo4jOptions.RELATIONSHIP_SOURCE_SAVE_MODE}` `Overwrite`"
            )
            ValidationUtil.isTrue(
              neo4jOptions.relationshipMetadata.targetSaveMode == NodeSaveMode.Overwrite,
              s"This works only with `${Neo4jOptions.RELATIONSHIP_TARGET_SAVE_MODE}` `Overwrite`"
            )
          }
        }
    }
  }
}

/** Compares the running Spark version step-by-step against the supported versions. */
case class ValidateSparkMinVersion(supportedVersions: String*) extends Validation {

  override def validate(): Unit = {
    val sparkVersion = SparkSession.getActiveSession
      .map(_.version)
      .getOrElse("UNKNOWN")
    ValidationUtil.isTrue(
      isSupported(sparkVersion),
      s"""Your current Spark version $sparkVersion is not supported by the current connector.
         |Please visit https://neo4j.com/developer/spark/overview/#_spark_compatibility to know which connector version you need.
         |""".stripMargin
    )
  }

  // Pairwise numeric comparison of version components; non-numeric components
  // (e.g. "3.x") are skipped. Unknown versions are assumed supported.
  def isSupported(sparkVersion: String): Boolean = {
    val splittedVersion = sparkVersion.split("\\.")
    if (sparkVersion == "UNKNOWN") return true
    val versions = supportedVersions
      .flatMap(_.split("\\.").zip(splittedVersion))
      .map(t =>
        try {
          (t._1.toInt, t._2.toInt)
        } catch {
          case _: NumberFormatException => null
        }
      )
      .filter(p => p != null)
    for (t <- versions) {
      val curr = t._2
      val supported = t._1
      if (curr > supported) { // if Spark current version (step) is greater than the supported one
        return true // we can assume that the current version is a greater version, so it's supported
      }
      if (curr < supported) { // if Spark current version (step) is lower than the supported one
        return false // we can assume that the current version is a greater version, so it's supported
      }
      // if the versions are equal we can check the next step
    }
    true // this happens if the two versions are equal
  }
}

/** Verifies driver connectivity; the cache is only closed when verification fails. */
case class ValidateConnection(neo4jOptions: Neo4jOptions, jobId: String) extends Validation {

  override def validate(): Unit = {
    var driverCache: DriverCache = null
    var hasError = false
    try {
      driverCache = new
DriverCache(neo4jOptions.connection)
      driverCache.getOrCreate().verifyConnectivity()
    } catch {
      case e: Throwable => {
        hasError = true
        throw e
      }
    } finally {
      // Only tear the cache down on failure; on success callers keep reusing it.
      if (hasError) {
        Neo4jUtil.closeSafely(driverCache)
      }
    }
  }
}

/** Rejects save modes outside SUPPORTED_SAVE_MODES (Overwrite, ErrorIfExists, Append). */
case class ValidateSaveMode(saveMode: String) extends Validation {

  override def validate(): Unit = {
    ValidationUtil.isTrue(
      Neo4jOptions.SUPPORTED_SAVE_MODES.contains(SaveMode.valueOf(saveMode)),
      s"""Unsupported SaveMode.
         |You provided $saveMode, supported are:
         |${Neo4jOptions.SUPPORTED_SAVE_MODES.mkString(",")}
         |""".stripMargin
    )
  }
}

/**
 * Full pre-write validation: connectivity, option consistency, schema
 * optimization rules, query-type specific requirements, and script syntax.
 * `customValidation` is a caller-supplied hook run last.
 */
case class ValidateWrite(
  neo4j: Neo4j,
  neo4jOptions: Neo4jOptions,
  jobId: String,
  saveMode: SaveMode,
  customValidation: Neo4jOptions => Unit = _ => ()
) extends Validation {

  override def validate(): Unit = {
    ValidationUtil.isFalse(
      neo4jOptions.session.accessMode == AccessMode.READ,
      s"Mode READ not supported for Data Source writer"
    )
    val cache = new DriverCache(neo4jOptions.connection)
    val schemaService = new SchemaService(neo4j, neo4jOptions, cache)
    try {
      ValidateConnection(neo4jOptions, jobId).validate()
      ValidateNeo4jOptionsConsistency(neo4jOptions).validate()
      ValidateSchemaMetadataWrite(neo4jOptions, saveMode).validate()
      neo4jOptions.query.queryType match {
        case QueryType.QUERY => {
          // Custom queries must be valid write queries; the query is checked with
          // the event/script-result variables bound the same way the writer binds them.
          val error = schemaService.validateQuery(
            s"""WITH {} AS ${Neo4jQueryStrategy.VARIABLE_EVENT}, [] as ${Neo4jQueryStrategy.VARIABLE_SCRIPT_RESULT}
               |${neo4jOptions.query.value}
               |""".stripMargin,
            org.neo4j.driver.summary.QueryType.WRITE_ONLY,
            org.neo4j.driver.summary.QueryType.READ_WRITE
          )
          ValidationUtil.isTrue(error.isEmpty, error)
          neo4jOptions.schemaMetadata.optimizationType match {
            case OptimizationType.NONE => // are valid
            case _ => ValidationUtil.isNotValid(
                s"""With Query Type ${neo4jOptions.query.queryType} you can
                   |only use `${Neo4jOptions.SCHEMA_OPTIMIZATION_TYPE}`
                   |`${OptimizationType.NONE}`
                   |""".stripMargin
              )
          }
        }
        case QueryType.LABELS => {
          // Overwrite needs node keys to build the MERGE clause.
          saveMode match {
            case SaveMode.Overwrite => {
              ValidationUtil.isNotEmpty(
                neo4jOptions.nodeMetadata.nodeKeys,
                s"${Neo4jOptions.NODE_KEYS} is required when Save Mode is Overwrite"
              )
            }
            case _ => ()
          }
        }
        case QueryType.RELATIONSHIP => {
          // FIX: was checking `target.labels` twice; the first check must
          // validate the SOURCE labels (matching its error message).
          ValidationUtil.isNotEmpty(
            neo4jOptions.relationshipMetadata.source.labels,
            s"${Neo4jOptions.RELATIONSHIP_SOURCE_LABELS} is required when Save Mode is Overwrite"
          )
          ValidationUtil.isNotEmpty(
            neo4jOptions.relationshipMetadata.target.labels,
            s"${Neo4jOptions.RELATIONSHIP_TARGET_LABELS} is required when Save Mode is Overwrite"
          )
        }
      }
      // Every statement of the setup script must at least parse.
      neo4jOptions.script.foreach(query =>
        ValidationUtil.isTrue(
          schemaService.isValidQuery(query),
          s"The following query inside the `${Neo4jOptions.SCRIPT}` is not valid, please check the syntax: $query"
        )
      )
      customValidation(neo4jOptions)
    } finally {
      schemaService.close()
      cache.close()
    }
  }
}

/** Pre-read validation: connectivity, option consistency, and query-type requirements. */
case class ValidateRead(neo4j: Neo4j, neo4jOptions: Neo4jOptions, jobId: String) extends Validation {

  override def validate(): Unit = {
    val cache = new DriverCache(neo4jOptions.connection)
    val schemaService = new SchemaService(neo4j, neo4jOptions, cache)
    try {
      ValidateConnection(neo4jOptions, jobId).validate()
      ValidateNeo4jOptionsConsistency(neo4jOptions).validate()
      neo4jOptions.query.queryType match {
        case QueryType.LABELS => {
          ValidationUtil.isNotEmpty(
            neo4jOptions.nodeMetadata.labels,
            s"You need to set the ${QueryType.LABELS.toString.toLowerCase} option"
          )
        }
        case QueryType.RELATIONSHIP => {
          ValidationUtil.isNotBlank(
            neo4jOptions.relationshipMetadata.relationshipType,
            s"You need to set the ${QueryType.RELATIONSHIP.toString.toLowerCase} option"
          )
          ValidationUtil.isNotEmpty(
            neo4jOptions.relationshipMetadata.source.labels,
            s"You need to set the ${Neo4jOptions.RELATIONSHIP_SOURCE_LABELS} option"
          )
          ValidationUtil.isNotEmpty(
            neo4jOptions.relationshipMetadata.target.labels,
            s"You need to set the ${Neo4jOptions.RELATIONSHIP_TARGET_LABELS} option"
          )
        }
        case QueryType.QUERY => {
          ValidationUtil.isFalse(
            neo4jOptions.query.value.matches("(?si).*(LIMIT \\d+|SKIP ?\\d+)\\s*\\z"),
            "SKIP/LIMIT are not allowed at
the end of the query" ) val queryError = schemaService.validateQuery( s"""WITH [] as ${Neo4jQueryStrategy.VARIABLE_SCRIPT_RESULT} |${neo4jOptions.query.value} |""".stripMargin, org.neo4j.driver.summary.QueryType.READ_ONLY ) ValidationUtil.isTrue(queryError.isEmpty, queryError) if (neo4jOptions.queryMetadata.queryCount.nonEmpty) { if (!Neo4jUtil.isLong(neo4jOptions.queryMetadata.queryCount)) { val queryCountError = schemaService.validateQueryCount(neo4jOptions.queryMetadata.queryCount) ValidationUtil.isTrue(queryCountError.isEmpty, queryCountError) } } } case QueryType.GDS => { ValidationUtil.isFalse( neo4jOptions.query.value.contains(".mutate") || neo4jOptions.query.value.contains(".write"), "You cannot execute GDS mutate or write procedure in a read query" ) ValidationUtil.isTrue( schemaService.isGdsProcedure(neo4jOptions.query.value), s"GDS procedure ${neo4jOptions.query.value} does not exist" ) ValidationUtil.isTrue(neo4jOptions.partitions == 1, "For GDS queries we support only one partition") Validations.validate(ValidateGdsMetadata(neo4jOptions.gdsMetadata)) } } val scriptErrors = neo4jOptions.script .map(schemaService.validateQuery(_)) .filter(_.nonEmpty) .mkString("\n") ValidationUtil.isTrue( scriptErrors.isEmpty, s""" |The following queries inside the `${Neo4jOptions.SCRIPT}` are not valid, |please check their syntax: |$scriptErrors |""".stripMargin ) } finally { schemaService.close() cache.close() } } } case class ValidateReadNotStreaming(neo4jOptions: Neo4jOptions, jobId: String) extends Validation { override def validate(): Unit = { ValidationUtil.isBlank( neo4jOptions.streamingOptions.propertyName, s"You don't need to set the `${Neo4jOptions.STREAMING_PROPERTY_NAME}` option" ) } } /** * df: this method checks for inconsistencies between provided options. * Ex: if we use the QueryType.LABELS, we will ignore any relationship options. * * Plus it throws an exception if no QueryType is provided. 
*/
case class ValidateNeo4jOptionsConsistency(neo4jOptions: Neo4jOptions) extends Validation {

  /**
   * Fails fast when no query-type option is present, then warns about any
   * metadata options that do not apply to the selected query type.
   */
  override def validate(): Unit = {
    if (neo4jOptions.query.value.isEmpty) {
      val requiredOptions = QueryType.values.map(qt => s"`${qt.toString}`").mkString(", ")
      throw new IllegalArgumentException(s"No valid option found. One of $requiredOptions is required")
    }
    neo4jOptions.query.queryType match {
      case QueryType.LABELS =>
        ignoreQueryMetadata(QueryType.LABELS)
        ignoreRelMetadata(QueryType.LABELS)
        ignoreGdsMetadata(QueryType.LABELS)
      case QueryType.RELATIONSHIP =>
        ignoreQueryMetadata(QueryType.RELATIONSHIP)
        ignoreNodeMetadata(QueryType.RELATIONSHIP)
        ignoreGdsMetadata(QueryType.RELATIONSHIP)
      case QueryType.QUERY =>
        ignoreNodeMetadata(QueryType.QUERY)
        ignoreRelMetadata(QueryType.QUERY)
        ignoreGdsMetadata(QueryType.QUERY)
      case QueryType.GDS =>
        ignoreQueryMetadata(QueryType.GDS)
        ignoreNodeMetadata(QueryType.GDS)
        ignoreRelMetadata(QueryType.GDS)
    }
  }

  // lower-cased query type, i.e. the name of the option the user actually set
  private def queryTypeAsOptionString(queryType: util.QueryType.Value): String =
    queryType.toString.toLowerCase(Locale.ENGLISH)

  // NOTE(review): this reports Neo4jOptions.QUERY_COUNT as the ignored option even
  // though it is triggered by GDS parameters — looks like a copy-paste from
  // ignoreQueryMetadata; confirm the intended option name.
  private def ignoreGdsMetadata(queryType: QueryType.Value): Unit = {
    if (!neo4jOptions.gdsMetadata.parameters.isEmpty) {
      ignoreOption(Neo4jOptions.QUERY_COUNT, queryTypeAsOptionString(queryType))
    }
  }

  private def ignoreQueryMetadata(queryType: QueryType.Value): Unit = {
    if (neo4jOptions.queryMetadata.queryCount.nonEmpty) {
      ignoreOption(Neo4jOptions.QUERY_COUNT, queryTypeAsOptionString(queryType))
    }
  }

  // warn about every relationship option that was set but will be ignored
  private def ignoreRelMetadata(queryType: QueryType.Value): Unit = {
    val ignoredFor = queryTypeAsOptionString(queryType)
    if (neo4jOptions.relationshipMetadata.source.labels.nonEmpty) {
      ignoreOption(Neo4jOptions.RELATIONSHIP_SOURCE_LABELS, ignoredFor)
    }
    if (neo4jOptions.relationshipMetadata.source.properties.nonEmpty) {
      ignoreOption(Neo4jOptions.RELATIONSHIP_SOURCE_NODE_PROPS, ignoredFor)
    }
    if (neo4jOptions.relationshipMetadata.source.nodeKeys.nonEmpty) {
      ignoreOption(Neo4jOptions.RELATIONSHIP_SOURCE_NODE_KEYS, ignoredFor)
    }
    if (neo4jOptions.relationshipMetadata.target.labels.nonEmpty) {
      ignoreOption(Neo4jOptions.RELATIONSHIP_TARGET_LABELS, ignoredFor)
    }
    if (neo4jOptions.relationshipMetadata.target.properties.nonEmpty) {
      ignoreOption(Neo4jOptions.RELATIONSHIP_TARGET_NODE_PROPS, ignoredFor)
    }
    if (neo4jOptions.relationshipMetadata.target.nodeKeys.nonEmpty) {
      ignoreOption(Neo4jOptions.RELATIONSHIP_TARGET_NODE_KEYS, ignoredFor)
    }
  }

  // warn about every node option that was set but will be ignored
  private def ignoreNodeMetadata(queryType: QueryType.Value): Unit = {
    val ignoredFor = queryTypeAsOptionString(queryType)
    if (neo4jOptions.nodeMetadata.nodeKeys.nonEmpty) {
      ignoreOption(Neo4jOptions.NODE_KEYS, ignoredFor)
    }
    if (neo4jOptions.nodeMetadata.properties.nonEmpty) {
      ignoreOption(Neo4jOptions.NODE_PROPS, ignoredFor)
    }
  }
}

/** Requires one of the two GDS graph-name parameters to be present. */
case class ValidateGdsMetadata(neo4jGdsMetadata: Neo4jGdsMetadata) extends Validation {

  override def validate(): Unit = {
    val parameters = neo4jGdsMetadata.parameters
    ValidationUtil.isTrue(
      parameters.get("graphName") != null || parameters.get("graphNameOrConfiguration") != null,
      "One between gds.graphName or gds.graphNameOrConfiguration is required"
    )
  }
}

/** Validates the options specific to a streaming read. */
case class ValidateReadStreaming(neo4j: Neo4j, neo4jOptions: Neo4jOptions, jobId: String) extends Validation {

  override def validate(): Unit = {
    val cache = new DriverCache(neo4jOptions.connection)
    val schemaService = new SchemaService(neo4j, neo4jOptions, cache)
    try {
      ValidationUtil.isTrue(
        neo4jOptions.partitions == 1,
        "For Spark Structured Streaming we support only one partition"
      )
      neo4jOptions.query.queryType match {
        case QueryType.QUERY => {
          ValidationUtil.isTrue(
            schemaService.isValidQuery(neo4jOptions.streamingOptions.queryOffset, summary.QueryType.READ_ONLY),
            """
              |Please set `streaming.query.offset` with a valid Cypher READ_ONLY query
              |that returns a long value i.e.
|MATCH (p:MyLabel)
            |RETURN max(p.timestamp)
            |""".stripMargin
          )
        }
        case _ =>
      }
    } finally {
      // always release the schema service and the cached driver
      schemaService.close()
      cache.close()
    }
  }
}

================================================
FILE: common/src/main/scala/org/neo4j/spark/writer/BaseDataWriter.scala
================================================
/*
 * Copyright (c) "Neo4j"
 * Neo4j Sweden AB [https://neo4j.com]
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package org.neo4j.spark.writer

import org.apache.spark.internal.Logging
import org.apache.spark.sql.SaveMode
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.connector.metric.CustomTaskMetric
import org.apache.spark.sql.connector.write.DataWriter
import org.apache.spark.sql.types.StructType
import org.neo4j.caniuse.Neo4j
import org.neo4j.driver.Session
import org.neo4j.driver.Transaction
import org.neo4j.driver.Values
import org.neo4j.driver.exceptions.ServiceUnavailableException
import org.neo4j.spark.service._
import org.neo4j.spark.util.DriverCache
import org.neo4j.spark.util.Neo4jOptions
import org.neo4j.spark.util.Neo4jUtil.closeSafely
import org.neo4j.spark.util.Neo4jUtil.isRetryableException

import java.io.Closeable
import java.time.Duration
import java.util
import java.util.concurrent.CountDownLatch
import java.util.concurrent.locks.LockSupport
import scala.annotation.tailrec
import scala.collection.JavaConverters._

/**
 * Base partition-level writer: buffers rows into batches and flushes each batch
 * to Neo4j in its own transaction, retrying transient failures.
 */
abstract class BaseDataWriter(
  neo4j: Neo4j,
  jobId: String,
  partitionId: Int,
  structType: StructType,
  saveMode: SaveMode,
  options: Neo4jOptions,
  scriptResult: java.util.List[java.util.Map[String, AnyRef]]
) extends Logging with Closeable with DataWriter[InternalRow] {

  // Driver message emitted when the writing thread is interrupted mid-commit;
  // treated as a graceful stop rather than an error (see logAndThrowException below).
  private val STOPPED_THREAD_EXCEPTION_MESSAGE =
    "Connection to the database terminated. Thread interrupted while committing the transaction"

  private val driverCache: DriverCache = new DriverCache(options.connection)

  // lazily (re)created per batch in writeBatch
  private var transaction: Transaction = _
  private var session: Session = _

  private val mappingService = new MappingService(new Neo4jWriteMappingStrategy(options), options)

  // rows accumulated until batchSize is reached, then flushed in one transaction
  private val batch: util.List[java.util.Map[String, Object]] = new util.ArrayList[util.Map[String, Object]]()

  // counts down once per retry attempt; when exhausted the failure is rethrown
  private val retries = new CountDownLatch(options.transactionSettings.retries)

  private val query: String = new Neo4jQueryService(options, new Neo4jQueryWriteStrategy(neo4j, saveMode)).createQuery()

  private val metrics = DataWriterMetrics()

  // rows dropped because the mapping produced no event (e.g. null key property values)
  private var skipped = 0

  /** Buffers one row; flushes the batch once it reaches the configured batch size. */
  def write(record: InternalRow): Unit = {
    val mapped = mappingService.convert(record, structType)
    mapped match {
      case Some(m) => batch.add(m)
      case None => skipped += 1
    }
    if (batch.size() == options.transactionSettings.batchSize) {
      writeBatch()
    }
  }

  /**
   * Flushes the current batch in a single transaction. On a retryable failure the
   * session/transaction are closed and the whole batch is retried (tail-recursively)
   * after the configured timeout, until the retry budget is exhausted.
   */
  @tailrec
  private def writeBatch(): Unit = {
    try {
      if (session == null || !session.isOpen) {
        session = driverCache.getOrCreate().session(options.session.toNeo4jSession())
      }
      if (transaction == null || !transaction.isOpen) {
        transaction = session.beginTransaction(options.toNeo4jTransactionConfig)
      }
      log.info(
        s"""Writing a batch of ${batch.size()} elements to Neo4j,
           |for jobId=$jobId and partitionId=$partitionId
           |with query: $query
           |""".stripMargin
      )
      // the batch and the script result are passed as query parameters
      val result = transaction.run(
        query,
        Values.value(Map[String, AnyRef](
          Neo4jQueryStrategy.VARIABLE_EVENTS -> batch,
          Neo4jQueryStrategy.VARIABLE_SCRIPT_RESULT -> scriptResult
        ).asJava)
      )
      val summary = result.consume()
      val counters = summary.counters()
      if (log.isDebugEnabled) {
        log.debug(
          s"""Batch saved into Neo4j data with:
             | - nodes created: ${counters.nodesCreated()}
             | - nodes deleted: ${counters.nodesDeleted()}
             | - relationships created: ${counters.relationshipsCreated()}
             | - relationships deleted: ${counters.relationshipsDeleted()}
             | - properties set: ${counters.propertiesSet()}
             | - labels added: ${counters.labelsAdded()}
             | - labels removed: ${counters.labelsRemoved()}
             |""".stripMargin
        )
      }
      transaction.commit()
      if (skipped > 0) {
        log.info(s"Skipped $skipped rows that contained null values in one of their key property values.")
        skipped = 0
      }
      // update metrics
      metrics.applyCounters(batch.size(), counters)
      closeSafely(transaction)
      batch.clear()
    } catch {
      case e: Throwable =>
        // explicitly configured failure conditions are never retried
        if (options.transactionSettings.shouldFailOn(e)) {
          log.error("unable to write batch due to explicitly configured failure condition", e)
          throw e
        }
        if (isRetryableException(e) && retries.getCount > 0) {
          retries.countDown()
          log.info(
            s"encountered a transient exception while writing batch, retrying ${options.transactionSettings.retries - retries.getCount} time",
            e
          )
          // drop the broken session/transaction, back off, then retry the same batch
          close()
          LockSupport.parkNanos(Duration.ofMillis(options.transactionSettings.retryTimeout).toNanos)
          writeBatch()
        } else {
          logAndThrowException(e)
        }
    }
  }

  /**
   * df: we check if the thrown exception is STOPPED_THREAD_EXCEPTION. This is the
   * exception that is thrown when the streaming query is interrupted, we don't want to cause
   * any error in this case. The transactions are rolled back automatically.
  private def logAndThrowException(e: Throwable): Unit = {
    // interrupted-stream shutdown is only logged as a warning; everything else is an error
    if (e.isInstanceOf[ServiceUnavailableException] && e.getMessage == STOPPED_THREAD_EXCEPTION_MESSAGE) {
      logWarning(e.getMessage)
    } else {
      logError("unable to write batch", e)
    }
    throw e
  }

  // Flushes any buffered rows, then releases the session/transaction.
  // Returns null because this writer produces no commit message for the driver-side coordinator.
  def commit(): Null = {
    writeBatch()
    close()
    null
  }

  // Rolls back the open transaction (best effort) and releases the resources.
  def abort(): Unit = {
    if (transaction != null && transaction.isOpen) {
      try {
        transaction.rollback()
      } catch {
        case e: Throwable => log.warn("Cannot rollback the transaction because of the following exception", e)
      }
    }
    close()
  }

  // Safe to call multiple times; closeSafely tolerates null/closed resources.
  def close(): Unit = {
    closeSafely(transaction, log)
    closeSafely(session, log)
  }

  override def currentMetricsValues(): Array[CustomTaskMetric] = metrics.metricValues()
}

================================================
FILE: common/src/main/scala/org/neo4j/spark/writer/DataWriterMetrics.scala
================================================
/*
 * Copyright (c) "Neo4j"
 * Neo4j Sweden AB [https://neo4j.com]
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
*/
package org.neo4j.spark.writer

import org.apache.spark.sql.connector.metric.CustomMetric
import org.apache.spark.sql.connector.metric.CustomSumMetric
import org.apache.spark.sql.connector.metric.CustomTaskMetric
import org.neo4j.driver.summary.SummaryCounters
import org.neo4j.spark.writer.DataWriterMetrics.LABELS_ADDED
import org.neo4j.spark.writer.DataWriterMetrics.LABELS_REMOVED
import org.neo4j.spark.writer.DataWriterMetrics.NODES_CREATED
import org.neo4j.spark.writer.DataWriterMetrics.NODES_DELETED
import org.neo4j.spark.writer.DataWriterMetrics.PROPERTIES_SET
import org.neo4j.spark.writer.DataWriterMetrics.RECORDS_WRITTEN
import org.neo4j.spark.writer.DataWriterMetrics.RELATIONSHIPS_CREATED
import org.neo4j.spark.writer.DataWriterMetrics.RELATIONSHIPS_DELETED

import java.util.concurrent.atomic.AtomicLong

/** A single named metric value reported back to Spark for one task. */
case class DataWriterMetric(name: String, value: Long) extends CustomTaskMetric

/**
 * Accumulates the write counters Neo4j reports after every committed batch.
 * Instances are obtained through the companion's `apply()`.
 */
class DataWriterMetrics private (
  recordsProcessed: AtomicLong,
  nodesCreated: AtomicLong,
  nodesDeleted: AtomicLong,
  relationshipsCreated: AtomicLong,
  relationshipsDeleted: AtomicLong,
  propertiesSet: AtomicLong,
  labelsAdded: AtomicLong,
  labelsRemoved: AtomicLong
) {

  /** Folds the counters of one committed batch into the running totals. */
  def applyCounters(recordsWritten: Long, counters: SummaryCounters): Unit = {
    recordsProcessed.addAndGet(recordsWritten)
    nodesCreated.addAndGet(counters.nodesCreated())
    nodesDeleted.addAndGet(counters.nodesDeleted())
    relationshipsCreated.addAndGet(counters.relationshipsCreated())
    relationshipsDeleted.addAndGet(counters.relationshipsDeleted())
    propertiesSet.addAndGet(counters.propertiesSet())
    labelsAdded.addAndGet(counters.labelsAdded())
    labelsRemoved.addAndGet(counters.labelsRemoved())
  }

  /** Snapshot of all counters as Spark task metrics. */
  def metricValues(): Array[CustomTaskMetric] = Array[CustomTaskMetric](
    DataWriterMetric(RECORDS_WRITTEN, recordsProcessed.longValue()),
    DataWriterMetric(NODES_CREATED, nodesCreated.longValue()),
    DataWriterMetric(NODES_DELETED, nodesDeleted.longValue()),
    DataWriterMetric(RELATIONSHIPS_CREATED, relationshipsCreated.longValue()),
    DataWriterMetric(RELATIONSHIPS_DELETED, relationshipsDeleted.longValue()),
    DataWriterMetric(PROPERTIES_SET, propertiesSet.longValue()),
    DataWriterMetric(LABELS_ADDED, labelsAdded.longValue()),
    DataWriterMetric(LABELS_REMOVED, labelsRemoved.longValue())
  )
}

object DataWriterMetrics {

  // metric names and the descriptions Spark shows in the UI
  final val RECORDS_WRITTEN = "neo4jMetrics.recordsWritten"
  final val RECORDS_WRITTEN_DESCRIPTION = "number of records written"
  final val NODES_CREATED = "neo4jMetrics.nodesCreated"
  final val NODES_CREATED_DESCRIPTION = "number of nodes created"
  final val NODES_DELETED = "neo4jMetrics.nodesDeleted"
  final val NODES_DELETED_DESCRIPTION = "number of nodes deleted"
  final val RELATIONSHIPS_CREATED = "neo4jMetrics.relationshipsCreated"
  final val RELATIONSHIPS_CREATED_DESCRIPTION = "number of relationships created"
  final val RELATIONSHIPS_DELETED = "neo4jMetrics.relationshipsDeleted"
  final val RELATIONSHIPS_DELETED_DESCRIPTION = "number of relationships deleted"
  final val PROPERTIES_SET = "neo4jMetrics.propertiesSet"
  final val PROPERTIES_SET_DESCRIPTION = "number of properties set"
  final val LABELS_ADDED = "neo4jMetrics.labelsAdded"
  final val LABELS_ADDED_DESCRIPTION = "number of labels added"
  final val LABELS_REMOVED = "neo4jMetrics.labelsRemoved"
  final val LABELS_REMOVED_DESCRIPTION = "number of labels removed"

  /** Creates a metrics holder with every counter starting at zero. */
  def apply(): DataWriterMetrics = {
    def zero = new AtomicLong(0)
    new DataWriterMetrics(zero, zero, zero, zero, zero, zero, zero, zero)
  }

  /** Driver-side metric declarations matching the task metrics above. */
  def metricDeclarations(): Array[CustomMetric] = Array[CustomMetric](
    new RecordsWrittenMetric,
    new NodesCreatedMetric,
    new NodesDeletedMetric,
    new RelationshipsCreatedMetric,
    new RelationshipsDeletedMetric,
    new PropertiesSetMetric,
    new LabelsAddedMetric,
    new LabelsRemovedMetric
  )
}

// Driver-side declarations: each sums the matching task metric across all tasks.

class RecordsWrittenMetric extends CustomSumMetric {
  override def name(): String = DataWriterMetrics.RECORDS_WRITTEN
  override def description(): String = DataWriterMetrics.RECORDS_WRITTEN_DESCRIPTION
}

class NodesCreatedMetric extends CustomSumMetric {
  override def name(): String = DataWriterMetrics.NODES_CREATED
  override def description(): String = DataWriterMetrics.NODES_CREATED_DESCRIPTION
}

class NodesDeletedMetric extends CustomSumMetric {
  override def name(): String = DataWriterMetrics.NODES_DELETED
  override def description(): String = DataWriterMetrics.NODES_DELETED_DESCRIPTION
}

class RelationshipsCreatedMetric extends CustomSumMetric {
  override def name(): String = DataWriterMetrics.RELATIONSHIPS_CREATED
  override def description(): String = DataWriterMetrics.RELATIONSHIPS_CREATED_DESCRIPTION
}

class RelationshipsDeletedMetric extends CustomSumMetric {
  override def name(): String = DataWriterMetrics.RELATIONSHIPS_DELETED
  override def description(): String = DataWriterMetrics.RELATIONSHIPS_DELETED_DESCRIPTION
}

class PropertiesSetMetric extends CustomSumMetric {
  override def name(): String = DataWriterMetrics.PROPERTIES_SET
  override def description(): String = DataWriterMetrics.PROPERTIES_SET_DESCRIPTION
}

class LabelsAddedMetric extends CustomSumMetric {
  override def name(): String = DataWriterMetrics.LABELS_ADDED
  override def description(): String = DataWriterMetrics.LABELS_ADDED_DESCRIPTION
}

class LabelsRemovedMetric extends CustomSumMetric {
  override def name(): String = DataWriterMetrics.LABELS_REMOVED
  override def description(): String = DataWriterMetrics.LABELS_REMOVED_DESCRIPTION
}

================================================
FILE: common/src/test/scala/org/neo4j/spark/CommonTestSuiteIT.scala
================================================
/*
 * Copyright (c) "Neo4j"
 * Neo4j Sweden AB [https://neo4j.com]
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package org.neo4j.spark

import org.junit.runner.RunWith
import org.junit.runners.Suite
import org.neo4j.spark.service.SchemaServiceTSE

// JUnit suite aggregating the common integration tests that need a shared Neo4j server.
@RunWith(classOf[Suite])
@Suite.SuiteClasses(Array(
  classOf[SchemaServiceTSE]
))
class CommonTestSuiteIT extends SparkConnectorScalaSuiteIT {}

================================================
FILE: common/src/test/scala/org/neo4j/spark/CommonTestSuiteWithApocIT.scala
================================================
/*
 * Copyright (c) "Neo4j"
 * Neo4j Sweden AB [https://neo4j.com]
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package org.neo4j.spark

import org.junit.runner.RunWith
import org.junit.runners.Suite
import org.neo4j.spark.service.SchemaServiceWithApocTSE

// JUnit suite aggregating the integration tests that require the APOC plugin.
@RunWith(classOf[Suite])
@Suite.SuiteClasses(Array(
  classOf[SchemaServiceWithApocTSE]
))
class CommonTestSuiteWithApocIT extends SparkConnectorScalaSuiteWithApocIT {}

================================================
FILE: common/src/test/scala/org/neo4j/spark/service/AuthenticationTest.scala
================================================
/*
 * Copyright (c) "Neo4j"
 * Neo4j Sweden AB [https://neo4j.com]
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
*/
package org.neo4j.spark.service

import org.junit.Test
import org.junit.runner.RunWith
import org.mockito.ArgumentMatchers
import org.mockito.ArgumentMatchers._
import org.mockito.Mockito.times
import org.neo4j.driver.AuthTokens
import org.neo4j.driver.Config
import org.neo4j.driver.GraphDatabase
import org.neo4j.spark.util.DriverCache
import org.neo4j.spark.util.Neo4jOptions
import org.powermock.api.mockito.PowerMockito
import org.powermock.core.classloader.annotations.PowerMockIgnore
import org.powermock.core.classloader.annotations.PrepareForTest
import org.powermock.modules.junit4.PowerMockRunner
import org.testcontainers.shaded.com.google.common.io.BaseEncoding

import java.net.URI
import java.util

// Uses PowerMock to intercept the static GraphDatabase.driver factory and assert
// that connector options are translated into the expected driver auth tokens.
@PrepareForTest(Array(classOf[GraphDatabase]))
@RunWith(classOf[PowerMockRunner])
@PowerMockIgnore(value = Array("javax.management.*"))
class AuthenticationTest {

  // "custom" authentication: the base64 credentials must be forwarded verbatim
  // as a custom auth token (empty principal/realm/scheme).
  @Test
  def testLdapConnectionToken(): Unit = {
    val token = BaseEncoding.base64.encode("user:password".getBytes)
    val options = new util.HashMap[String, String]
    options.put("url", "bolt://localhost:7687")
    options.put("authentication.type", "custom")
    options.put("authentication.custom.credentials", token)
    options.put("labels", "Person")
    val neo4jOptions = new Neo4jOptions(options)
    val neo4jDriverOptions = neo4jOptions.connection
    val driverCache = new DriverCache(neo4jDriverOptions)
    PowerMockito.spy(classOf[GraphDatabase])
    driverCache.getOrCreate()
    PowerMockito.verifyStatic(classOf[GraphDatabase], times(1))
    GraphDatabase.driver(any[URI](), ArgumentMatchers.eq(AuthTokens.custom("", token, "", "")), any(classOf[Config]))
  }

  // "bearer" authentication: the configured token must become a bearer auth token.
  @Test
  def testBearerAuthToken(): Unit = {
    val token = BaseEncoding.base64.encode("user:password".getBytes)
    val options = new util.HashMap[String, String]
    options.put("url", "bolt://localhost:7687")
    options.put("authentication.type", "bearer")
    options.put("authentication.bearer.token", token)
    val neo4jOptions = new Neo4jOptions(options)
    val neo4jDriverOptions = neo4jOptions.connection
    val
driverCache = new DriverCache(neo4jDriverOptions)
    PowerMockito.spy(classOf[GraphDatabase])
    driverCache.getOrCreate()
    // verify the static factory was invoked exactly once with a bearer token
    PowerMockito.verifyStatic(classOf[GraphDatabase], times(1))
    GraphDatabase.driver(any[URI](), ArgumentMatchers.eq(AuthTokens.bearer(token)), any())
  }
}

================================================
FILE: common/src/test/scala/org/neo4j/spark/service/Neo4jQueryServiceIT.scala
================================================
/*
 * Copyright (c) "Neo4j"
 * Neo4j Sweden AB [https://neo4j.com]
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
*/
package org.neo4j.spark.service

import org.apache.spark.sql.connector.expressions.aggregate.Count
import org.apache.spark.sql.connector.expressions.aggregate.Max
import org.apache.spark.sql.connector.expressions.aggregate.Min
import org.apache.spark.sql.connector.expressions.aggregate.Sum
import org.junit.After
import org.junit.FixMethodOrder
import org.junit.Test
import org.junit.runners.MethodSorters
import org.neo4j.spark.SparkConnectorScalaSuiteWithGdsBase
import org.neo4j.spark.SparkConnectorScalaSuiteWithGdsBase.neo4j
import org.neo4j.spark.util.DriverCache
import org.neo4j.spark.util.DummyNamedReference
import org.neo4j.spark.util.Neo4jOptions
import org.scalatest.matchers.must.Matchers.convertToAnyMustWrapper
import org.scalatest.matchers.must.Matchers.endWith

import scala.language.postfixOps

// Verifies the Cypher generated for GDS-based reads against a live GDS-enabled server.
@FixMethodOrder(MethodSorters.JVM)
class Neo4jQueryServiceIT extends SparkConnectorScalaSuiteWithGdsBase {

  // Closes the cached driver after each test so the suite doesn't leak connections.
  @After
  def cleanUp(): Unit = {
    val options: java.util.Map[String, String] = new java.util.HashMap[String, String]()
    options.put(Neo4jOptions.URL, SparkConnectorScalaSuiteWithGdsBase.server.getBoltUrl)
    val neo4jOptions: Neo4jOptions = new Neo4jOptions(options)
    new DriverCache(neo4jOptions.connection).close()
  }

  // Pushed-down aggregations over a GDS stream procedure must be rendered as
  // Cypher aggregate functions aliased with the original column names.
  @Test
  def testShouldDoAggregationOnGDS(): Unit = {
    val options: java.util.Map[String, String] = new java.util.HashMap[String, String]()
    options.put(Neo4jOptions.URL, SparkConnectorScalaSuiteWithGdsBase.server.getBoltUrl)
    options.put("gds", "gds.pageRank.stream")
    val neo4jOptions: Neo4jOptions = new Neo4jOptions(options)
    val field = new DummyNamedReference("score")
    val query: String = new Neo4jQueryService(
      neo4jOptions,
      new Neo4jQueryReadStrategy(
        neo4j,
        Array.empty,
        PartitionPagination.EMPTY,
        List(
          "nodeId",
          "MAX(score)",
          "MIN(score)",
          "COUNT(score)",
          "COUNT(DISTINCT score)",
          "SUM(score)",
          "SUM(DISTINCT score)"
        ),
        // NOTE(review): 7 aggregate expressions for 6 aggregate column names — the
        // extra `new Sum(field, false)` at position 3 looks like a duplicate; the
        // expected query below only covers 6 aggregates. Confirm intended mapping.
        Array(
          new Max(field),
          new Min(field),
          new Sum(field, false),
          new Count(field, false),
          new Count(field, true),
          new Sum(field, false),
          new Sum(field, true)
        )
      )
    ).createQuery()
    query must endWith(
      """CALL gds.pageRank.stream($graphName)
        |YIELD nodeId, score
        |RETURN nodeId AS nodeId, max(score) AS `MAX(score)`, min(score) AS `MIN(score)`, count(score) AS `COUNT(score)`, count(DISTINCT score) AS `COUNT(DISTINCT score)`, sum(score) AS `SUM(score)`, sum(DISTINCT score) AS `SUM(DISTINCT score)`"""
        .stripMargin
        .replaceAll("\n", " ")
    )
  }
}

================================================
FILE: common/src/test/scala/org/neo4j/spark/service/Neo4jQueryServiceTest.scala
================================================
/*
 * Copyright (c) "Neo4j"
 * Neo4j Sweden AB [https://neo4j.com]
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
*/
package org.neo4j.spark.service

import junitparams.JUnitParamsRunner
import junitparams.Parameters
import org.apache.spark.sql.SaveMode
import org.apache.spark.sql.connector.expressions.Expression
import org.apache.spark.sql.connector.expressions.NullOrdering
import org.apache.spark.sql.connector.expressions.SortDirection
import org.apache.spark.sql.connector.expressions.SortOrder
import org.apache.spark.sql.connector.expressions.aggregate.Count
import org.apache.spark.sql.connector.expressions.aggregate.Max
import org.apache.spark.sql.connector.expressions.aggregate.Min
import org.apache.spark.sql.connector.expressions.aggregate.Sum
import org.apache.spark.sql.sources._
import org.junit.Assert._
import org.junit.Test
import org.junit.runner.RunWith
import org.neo4j.caniuse.Neo4j
import org.neo4j.caniuse.Neo4jDeploymentType.SELF_MANAGED
import org.neo4j.caniuse.Neo4jEdition
import org.neo4j.caniuse.Neo4jEdition.COMMUNITY
import org.neo4j.caniuse.Neo4jEdition.ENTERPRISE
import org.neo4j.caniuse.Neo4jVersion
import org.neo4j.spark.config.TopN
import org.neo4j.spark.util.DummyNamedReference
import org.neo4j.spark.util.Neo4jImplicits.CypherImplicits
import org.neo4j.spark.util.Neo4jOptions
import org.neo4j.spark.util.QueryType
import scala.collection.immutable.HashMap

/**
 * Unit tests for [[Neo4jQueryService]] query generation.
 *
 * Every test is parameterized over (Neo4j server version/edition, expected query prefix)
 * via `versions_and_prefixes`: servers 4.4 / 5.0 produce no prefix, while 5.21+ and
 * calendar-versioned servers are expected to prepend "CYPHER 5 " to every generated query.
 *
 * The tests cover node (labels) reads, relationship reads, filter push-down, column
 * pruning, aggregate push-down, TopN (ORDER BY / LIMIT) push-down, and write-query
 * generation for node and relationship targets with compound keys.
 */
@RunWith(classOf[JUnitParamsRunner])
class Neo4jQueryServiceTest {

  // --- Node (labels) read queries -------------------------------------------------

  // A single label produces a plain MATCH ... RETURN n query with the label backtick-quoted.
  @Test
  @Parameters(method = "versions_and_prefixes")
  def testNodeOneLabel(neo4j: Neo4j, prefix: String): Unit = {
    val options: java.util.Map[String, String] = new java.util.HashMap[String, String]()
    options.put(Neo4jOptions.URL, "bolt://localhost")
    options.put(QueryType.LABELS.toString.toLowerCase, "Person")
    val neo4jOptions: Neo4jOptions = new Neo4jOptions(options)
    val query: String = new Neo4jQueryService(neo4jOptions, new Neo4jQueryReadStrategy(neo4j)).createQuery()
    assertEquals(s"${prefix}MATCH (n:`Person`) RETURN n", query)
  }

  // Colon-separated labels (with optional leading colon) are all applied to the node pattern.
  @Test
  @Parameters(method = "versions_and_prefixes")
  def testNodeMultipleLabels(neo4j: Neo4j, prefix: String): Unit = {
    val options: java.util.Map[String, String] = new java.util.HashMap[String, String]()
    options.put(Neo4jOptions.URL, "bolt://localhost")
    options.put(QueryType.LABELS.toString.toLowerCase, ":Person:Player:Midfield")
    val neo4jOptions: Neo4jOptions = new Neo4jOptions(options)
    val query: String = new Neo4jQueryService(neo4jOptions, new Neo4jQueryReadStrategy(neo4j)).createQuery()
    assertEquals(s"${prefix}MATCH (n:`Person`:`Player`:`Midfield`) RETURN n", query)
  }

  // A partition with a TopN limit (and no ordering) appends a bare LIMIT clause.
  @Test
  @Parameters(method = "versions_and_prefixes")
  def testNodeMultipleLabelsWithPartitions(neo4j: Neo4j, prefix: String): Unit = {
    val options: java.util.Map[String, String] = new java.util.HashMap[String, String]()
    options.put(Neo4jOptions.URL, "bolt://localhost")
    options.put(QueryType.LABELS.toString.toLowerCase, ":Person:Player:Midfield")
    val neo4jOptions: Neo4jOptions = new Neo4jOptions(options)
    val query: String = new Neo4jQueryService(
      neo4jOptions,
      new Neo4jQueryReadStrategy(
        neo4j,
        partitionPagination = PartitionPagination(0, 0, TopN(100))
      )
    ).createQuery()
    assertEquals(s"${prefix}MATCH (n:`Person`:`Player`:`Midfield`) RETURN n LIMIT 100", query)
  }

  // Column pruning: a single required column becomes "n.<col> AS <col>".
  @Test
  @Parameters(method = "versions_and_prefixes")
  def testNodeOneLabelWithOneSelectedColumn(neo4j: Neo4j, prefix: String): Unit = {
    val options: java.util.Map[String, String] = new java.util.HashMap[String, String]()
    options.put(Neo4jOptions.URL, "bolt://localhost")
    options.put(QueryType.LABELS.toString.toLowerCase, "Person")
    val neo4jOptions: Neo4jOptions = new Neo4jOptions(options)
    val query: String = new Neo4jQueryService(
      neo4jOptions,
      new Neo4jQueryReadStrategy(neo4j, Array.empty[Filter], PartitionPagination.EMPTY, Seq("name"))
    ).createQuery()
    assertEquals(s"${prefix}MATCH (n:`Person`) RETURN n.name AS name", query)
  }

  // Column pruning preserves the requested column order in the RETURN clause.
  @Test
  @Parameters(method = "versions_and_prefixes")
  def testNodeOneLabelWithMultipleColumnSelected(neo4j: Neo4j, prefix: String): Unit = {
    val options: java.util.Map[String, String] = new java.util.HashMap[String, String]()
    options.put(Neo4jOptions.URL, "bolt://localhost")
    options.put(QueryType.LABELS.toString.toLowerCase, "Person")
    val neo4jOptions: Neo4jOptions = new Neo4jOptions(options)
    val query: String = new Neo4jQueryService(
      neo4jOptions,
      new Neo4jQueryReadStrategy(neo4j, Array.empty[Filter], PartitionPagination.EMPTY, List("name", "bornDate"))
    ).createQuery()
    assertEquals(s"${prefix}MATCH (n:`Person`) RETURN n.name AS name, n.bornDate AS bornDate", query)
  }

  // An empty-string column name maps to the internal node id: id(n) aliased to `` (empty backticks).
  @Test
  @Parameters(method = "versions_and_prefixes")
  def testNodeOneLabelWithInternalIdSelected(neo4j: Neo4j, prefix: String): Unit = {
    val options: java.util.Map[String, String] = new java.util.HashMap[String, String]()
    options.put(Neo4jOptions.URL, "bolt://localhost")
    options.put(QueryType.LABELS.toString.toLowerCase, "Person")
    val neo4jOptions: Neo4jOptions = new Neo4jOptions(options)
    val query: String = new Neo4jQueryService(
      neo4jOptions,
      new Neo4jQueryReadStrategy(neo4j, Array.empty[Filter], PartitionPagination.EMPTY, List(""))
    ).createQuery()
    assertEquals(s"${prefix}MATCH (n:`Person`) RETURN id(n) AS ``", query)
  }

  // --- Filter push-down on nodes --------------------------------------------------

  // EqualTo is rendered as a parameterized equality; the parameter name is derived
  // from the attribute and value via Neo4jImplicits.toParameterName.
  @Test
  @Parameters(method = "versions_and_prefixes")
  def testNodeFilterEqualTo(neo4j: Neo4j, prefix: String): Unit = {
    val options: java.util.Map[String, String] = new java.util.HashMap[String, String]()
    options.put(Neo4jOptions.URL, "bolt://localhost")
    options.put(QueryType.LABELS.toString.toLowerCase, "Person")
    val neo4jOptions: Neo4jOptions = new Neo4jOptions(options)
    val filters: Array[Filter] = Array[Filter](
      EqualTo("name", "John Doe")
    )
    val query: String =
      new Neo4jQueryService(neo4jOptions, new Neo4jQueryReadStrategy(neo4j, filters)).createQuery()
    val paramName = "$" + "name".toParameterName("John Doe")
    assertEquals(s"${prefix}MATCH (n:`Person`) WHERE n.name = $paramName RETURN n", query)
  }

  // EqualNullSafe expands to "(prop IS NULL AND $p IS NULL) OR prop = $p".
  @Test
  @Parameters(method = "versions_and_prefixes")
  def testNodeFilterEqualNullSafe(neo4j: Neo4j, prefix: String): Unit = {
    val options: java.util.Map[String, String] = new java.util.HashMap[String, String]()
    options.put(Neo4jOptions.URL, "bolt://localhost")
    options.put(QueryType.LABELS.toString.toLowerCase, "Person")
    val neo4jOptions: Neo4jOptions = new Neo4jOptions(options)
    val filters: Array[Filter] = Array[Filter](
      EqualNullSafe("name", "John Doe"),
      EqualTo("age", 36)
    )
    val query: String =
      new Neo4jQueryService(neo4jOptions, new Neo4jQueryReadStrategy(neo4j, filters)).createQuery()
    val nameParameterName = "$" + "name".toParameterName("John Doe")
    val ageParameterName = "$" + "age".toParameterName(36)
    assertEquals(
      s"""${prefix}MATCH (n:`Person`)
         | WHERE (((n.name IS NULL AND $nameParameterName IS NULL)
         | OR n.name = $nameParameterName) AND n.age = $ageParameterName)
         | RETURN n""".stripMargin.replaceAll("\n", ""),
      query
    )
  }

  // Same null-safe expansion when the comparison value itself is null.
  @Test
  @Parameters(method = "versions_and_prefixes")
  def testNodeFilterEqualNullSafeWithNullValue(neo4j: Neo4j, prefix: String): Unit = {
    val options: java.util.Map[String, String] = new java.util.HashMap[String, String]()
    options.put(Neo4jOptions.URL, "bolt://localhost")
    options.put(QueryType.LABELS.toString.toLowerCase, "Person")
    val neo4jOptions: Neo4jOptions = new Neo4jOptions(options)
    val filters: Array[Filter] = Array[Filter](
      EqualNullSafe("name", null),
      EqualTo("age", 36)
    )
    val query: String =
      new Neo4jQueryService(neo4jOptions, new Neo4jQueryReadStrategy(neo4j, filters)).createQuery()
    val nameParameterName = "$" + "name".toParameterName(null)
    val ageParameterName = "$" + "age".toParameterName(36)
    assertEquals(
      s"${prefix}MATCH (n:`Person`) WHERE (((n.name IS NULL AND $nameParameterName IS NULL) OR n.name = $nameParameterName) AND n.age = $ageParameterName) RETURN n",
      query
    )
  }

  // StringStartsWith / StringEndsWith map to Cypher STARTS WITH / ENDS WITH.
  @Test
  @Parameters(method = "versions_and_prefixes")
  def testNodeFilterStartsEndsWith(neo4j: Neo4j, prefix: String): Unit = {
    val options: java.util.Map[String, String] = new java.util.HashMap[String, String]()
    options.put(Neo4jOptions.URL, "bolt://localhost")
    options.put(QueryType.LABELS.toString.toLowerCase, "Person")
    val neo4jOptions: Neo4jOptions = new Neo4jOptions(options)
    val filters: Array[Filter] = Array[Filter](
      StringStartsWith("name", "Person Name"),
      StringEndsWith("name", "Person Surname")
    )
    val query: String =
      new Neo4jQueryService(neo4jOptions, new Neo4jQueryReadStrategy(neo4j, filters)).createQuery()
    val nameOneParameterName = "$" + "name".toParameterName("Person Name")
    val nameTwoParameterName = "$" + "name".toParameterName("Person Surname")
    assertEquals(
      s"""${prefix}MATCH (n:`Person`)
         | WHERE (n.name STARTS WITH $nameOneParameterName
         | AND n.name ENDS WITH $nameTwoParameterName)
         | RETURN n""".stripMargin.replaceAll("\n", ""),
      query
    )
  }

  // --- Relationship read queries --------------------------------------------------

  // With relationship.nodes.map=false, a pruned "source.<col>" column is returned
  // with the dotted name backtick-quoted as the alias.
  @Test
  @Parameters(method = "versions_and_prefixes")
  def testRelationshipWithOneColumnSelected(neo4j: Neo4j, prefix: String): Unit = {
    val options: java.util.Map[String, String] = new java.util.HashMap[String, String]()
    options.put(Neo4jOptions.URL, "bolt://localhost")
    options.put("relationship", "KNOWS")
    options.put("relationship.nodes.map", "false")
    options.put("relationship.source.labels", "Person")
    options.put("relationship.target.labels", "Person")
    val neo4jOptions: Neo4jOptions = new Neo4jOptions(options)
    val query: String = new Neo4jQueryService(
      neo4jOptions,
      new Neo4jQueryReadStrategy(
        neo4j,
        Array.empty[Filter],
        PartitionPagination.EMPTY,
        List("source.name")
      )
    ).createQuery()
    assertEquals(
      s"${prefix}MATCH (source:`Person`) " +
        "MATCH (target:`Person`) " +
        "MATCH (source)-[rel:`KNOWS`]->(target) RETURN source.name AS `source.name`",
      query
    )
  }

  // An empty-string column alongside a pruned column yields the internal source id.
  @Test
  @Parameters(method = "versions_and_prefixes")
  def testRelationshipWithMoreColumnSelected(neo4j: Neo4j, prefix: String): Unit = {
    val options: java.util.Map[String, String] = new java.util.HashMap[String, String]()
    options.put(Neo4jOptions.URL, "bolt://localhost")
    options.put("relationship", "KNOWS")
    options.put("relationship.nodes.map", "false")
    options.put("relationship.source.labels", "Person")
    options.put("relationship.target.labels", "Person")
    val neo4jOptions: Neo4jOptions = new Neo4jOptions(options)
    val query: String = new Neo4jQueryService(
      neo4jOptions,
      new Neo4jQueryReadStrategy(
        neo4j,
        Array.empty[Filter],
        PartitionPagination.EMPTY,
        List("source.name", "")
      )
    ).createQuery()
    assertEquals(
      s"${prefix}MATCH (source:`Person`) " +
        "MATCH (target:`Person`) " +
        "MATCH (source)-[rel:`KNOWS`]->(target) RETURN source.name AS `source.name`, id(source) AS ``",
      query
    )
  }

  // Column pruning combined with a TopN limit appends LIMIT after the RETURN clause.
  @Test
  @Parameters(method = "versions_and_prefixes")
  def testRelationshipWithMoreColumnSelectedWithPartitions(neo4j: Neo4j, prefix: String): Unit = {
    val options: java.util.Map[String, String] = new java.util.HashMap[String, String]()
    options.put(Neo4jOptions.URL, "bolt://localhost")
    options.put("relationship", "KNOWS")
    options.put("relationship.nodes.map", "false")
    options.put("relationship.source.labels", "Person")
    options.put("relationship.target.labels", "Person")
    val neo4jOptions: Neo4jOptions = new Neo4jOptions(options)
    val query: String = new Neo4jQueryService(
      neo4jOptions,
      new Neo4jQueryReadStrategy(
        neo4j,
        Array.empty[Filter],
        PartitionPagination(0, 0, TopN(limit = 100)),
        List("source.name", "")
      )
    ).createQuery()
    assertEquals(
      s"""${prefix}MATCH (source:`Person`)
         |MATCH (target:`Person`)
         |MATCH (source)-[rel:`KNOWS`]->(target)
         |RETURN source.name AS `source.name`, id(source) AS ``
         |LIMIT 100"""
        .stripMargin
        .replace(System.lineSeparator(), " "),
      query
    )
  }

  // Pruned columns from source, rel and target are all projected with dotted aliases.
  @Test
  @Parameters(method = "versions_and_prefixes")
  def testRelationshipWithMoreColumnsSelected(neo4j: Neo4j, prefix: String): Unit = {
    val options: java.util.Map[String, String] = new java.util.HashMap[String, String]()
    options.put(Neo4jOptions.URL, "bolt://localhost")
    options.put("relationship", "KNOWS")
    options.put("relationship.nodes.map", "false")
    options.put("relationship.source.labels", "Person")
    options.put("relationship.target.labels", "Person")
    val neo4jOptions: Neo4jOptions = new Neo4jOptions(options)
    val query: String = new Neo4jQueryService(
      neo4jOptions,
      new Neo4jQueryReadStrategy(
        neo4j,
        Array.empty[Filter],
        PartitionPagination.EMPTY,
        List("source.name", "source.id", "rel.someprops", "target.date")
      )
    ).createQuery()
    assertEquals(
      s"${prefix}MATCH (source:`Person`) " +
        "MATCH (target:`Person`) " +
        "MATCH (source)-[rel:`KNOWS`]->(target) RETURN source.name AS `source.name`, source.id AS `source.id`, rel.someprops AS `rel.someprops`, target.date AS `target.date`",
      query
    )
  }

  // --- Filter push-down on relationships ------------------------------------------

  // A filter on "source.<prop>" is applied to the source node of the relationship pattern.
  @Test
  @Parameters(method = "versions_and_prefixes")
  def testRelationshipFilterEqualTo(neo4j: Neo4j, prefix: String): Unit = {
    val options: java.util.Map[String, String] = new java.util.HashMap[String, String]()
    options.put(Neo4jOptions.URL, "bolt://localhost")
    options.put("relationship", "KNOWS")
    options.put("relationship.nodes.map", "false")
    options.put("relationship.source.labels", "Person")
    options.put("relationship.target.labels", "Person")
    val neo4jOptions: Neo4jOptions = new Neo4jOptions(options)
    val filters: Array[Filter] = Array[Filter](
      EqualTo("source.name", "John Doe")
    )
    val query: String =
      new Neo4jQueryService(neo4jOptions, new Neo4jQueryReadStrategy(neo4j, filters)).createQuery()
    val parameterName = "$" + "source.name".toParameterName("John Doe")
    assertEquals(
      s"${prefix}MATCH (source:`Person`) " +
        "MATCH (target:`Person`) " +
        s"MATCH (source)-[rel:`KNOWS`]->(target) WHERE source.name = $parameterName RETURN rel, source AS source, target AS target",
      query
    )
  }

  // NOTE(review): despite the name, this test pushes down Or(EqualTo, EqualTo) —
  // there is no Not/NotEqualTo filter here; consider renaming or adding the intended case.
  @Test
  @Parameters(method = "versions_and_prefixes")
  def testRelationshipFilterNotEqualTo(neo4j: Neo4j, prefix: String): Unit = {
    val options: java.util.Map[String, String] = new java.util.HashMap[String, String]()
    options.put(Neo4jOptions.URL, "bolt://localhost")
    options.put("relationship", "KNOWS")
    options.put("relationship.nodes.map", "false")
    options.put("relationship.source.labels", "Person")
    options.put("relationship.target.labels", "Person")
    val neo4jOptions: Neo4jOptions = new Neo4jOptions(options)
    val filters: Array[Filter] = Array[Filter](
      Or(EqualTo("source.name", "John Doe"), EqualTo("target.name", "John Doe"))
    )
    val query: String =
      new Neo4jQueryService(neo4jOptions, new Neo4jQueryReadStrategy(neo4j, filters)).createQuery()
    val paramOneName = "$" + "source.name".toParameterName("John Doe")
    val paramTwoName = "$" + "target.name".toParameterName("John Doe")
    assertEquals(
      s"${prefix}MATCH (source:`Person`) " +
        "MATCH (target:`Person`) " +
        s"MATCH (source)-[rel:`KNOWS`]->(target) WHERE (source.name = $paramOneName OR target.name = $paramTwoName) RETURN rel, source AS source, target AS target",
      query
    )
  }

  // Two top-level filters are ANDed together in the WHERE clause.
  // NOTE(review): the filters use string values "14"/"16" while the expected parameter
  // names are derived from ints 14/16 — this only lines up if toParameterName ignores
  // the value's type; confirm against Neo4jImplicits.
  @Test
  @Parameters(method = "versions_and_prefixes")
  def testRelationshipAndFilterEqualTo(neo4j: Neo4j, prefix: String): Unit = {
    val options: java.util.Map[String, String] = new java.util.HashMap[String, String]()
    options.put(Neo4jOptions.URL, "bolt://localhost")
    options.put("relationship", "KNOWS")
    options.put("relationship.nodes.map", "true")
    options.put("relationship.source.labels", "Person")
    options.put("relationship.target.labels", "Person")
    val neo4jOptions: Neo4jOptions = new Neo4jOptions(options)
    val filters: Array[Filter] = Array[Filter](
      EqualTo("source.id", "14"),
      EqualTo("target.id", "16")
    )
    val query: String =
      new Neo4jQueryService(neo4jOptions, new Neo4jQueryReadStrategy(neo4j, filters)).createQuery()
    val sourceIdParameterName = "$" + "source.id".toParameterName(14)
    val targetIdParameterName = "$" + "target.id".toParameterName(16)
    assertEquals(
      s"""${prefix}MATCH (source:`Person`)
         | MATCH (target:`Person`)
         | MATCH (source)-[rel:`KNOWS`]->(target)
         | WHERE (source.id = $sourceIdParameterName AND target.id = $targetIdParameterName)
         | RETURN rel, source AS source, target AS target
         |""".stripMargin.replaceAll("\n", ""),
      query
    )
  }

  // Nested Or / Not / comparison filters are parenthesized correctly and ANDed at the top level.
  @Test
  @Parameters(method = "versions_and_prefixes")
  def testComplexNodeConditions(neo4j: Neo4j, prefix: String): Unit = {
    val options: java.util.Map[String, String] = new java.util.HashMap[String, String]()
    options.put(Neo4jOptions.URL, "bolt://localhost")
    options.put("labels", "Person")
    val neo4jOptions: Neo4jOptions = new Neo4jOptions(options)
    val filters: Array[Filter] = Array[Filter](
      Or(EqualTo("name", "John Doe"), EqualTo("name", "John Scofield")),
      Or(EqualTo("age", 15), GreaterThanOrEqual("age", 18)),
      Or(Not(EqualTo("age", 22)), Not(LessThan("age", 11)))
    )
    val query: String =
      new Neo4jQueryService(neo4jOptions, new Neo4jQueryReadStrategy(neo4j, filters)).createQuery()
    // Repeated attributes get suffixed parameter names (name_1, name_2, age_1, ...).
    val parameterNames: Map[String, String] = HashMap(
      "name_1" -> "$".concat("name".toParameterName("John Doe")),
      "name_2" -> "$".concat("name".toParameterName("John Scofield")),
      "age_1" -> "$".concat("age".toParameterName(15)),
      "age_2" -> "$".concat("age".toParameterName(18)),
      "age_3" -> "$".concat("age".toParameterName(22)),
      "age_4" -> "$".concat("age".toParameterName(11))
    )
    assertEquals(
      s"""${prefix}MATCH (n:`Person`)
         | WHERE (((n.name = ${parameterNames("name_1")} OR n.name = ${parameterNames("name_2")})
         | AND (n.age = ${parameterNames("age_1")} OR n.age >= ${parameterNames("age_2")}))
         | AND (NOT (n.age = ${parameterNames("age_3")}) OR NOT (n.age < ${parameterNames("age_4")})))
         | RETURN n""".stripMargin.replaceAll("\n", ""),
      query
    )
  }

  // Complex nested conditions across source/target/rel with nodes.map=false:
  // the nested Or flattens into a single OR chain in the generated WHERE.
  @Test
  @Parameters(method = "versions_and_prefixes")
  def testRelationshipFilterComplexConditionsNoMap(neo4j: Neo4j, prefix: String): Unit = {
    val options: java.util.Map[String, String] = new java.util.HashMap[String, String]()
    options.put(Neo4jOptions.URL, "bolt://localhost")
    options.put("relationship", "KNOWS")
    options.put("relationship.nodes.map", "false")
    options.put("relationship.source.labels", "Person")
    options.put("relationship.target.labels", "Person:Customer")
    val neo4jOptions: Neo4jOptions = new Neo4jOptions(options)
    val filters: Array[Filter] = Array[Filter](
      Or(
        Or(EqualTo("source.name", "John Doe"), EqualTo("target.name", "John Doraemon")),
        EqualTo("source.name", "Jane Doe")
      ),
      Or(EqualTo("target.age", 34), EqualTo("target.age", 18)),
      EqualTo("rel.score", 12)
    )
    val query: String =
      new Neo4jQueryService(neo4jOptions, new Neo4jQueryReadStrategy(neo4j, filters)).createQuery()
    val parameterNames = Map(
      "source.name_1" -> "$".concat("source.name".toParameterName("John Doe")),
      "target.name_1" -> "$".concat("target.name".toParameterName("John Doraemon")),
      "source.name_2" -> "$".concat("source.name".toParameterName("Jane Doe")),
      "target.age_1" -> "$".concat("target.age".toParameterName(34)),
      "target.age_2" -> "$".concat("target.age".toParameterName(18)),
      "rel.score" -> "$".concat("rel.score".toParameterName(12))
    )
    assertEquals(
      s"""${prefix}MATCH (source:`Person`)
         | MATCH (target:`Person`:`Customer`)
         | MATCH (source)-[rel:`KNOWS`]->(target)
         | WHERE ((source.name = ${parameterNames("source.name_1")} OR target.name = ${parameterNames("target.name_1")} OR source.name = ${parameterNames("source.name_2")})
         | AND (target.age = ${parameterNames("target.age_1")} OR target.age = ${parameterNames("target.age_2")})
         | AND rel.score = ${parameterNames("rel.score")})
         | RETURN rel, source AS source, target AS target""".stripMargin.replaceAll("\n", ""),
      query
    )
  }

  // Same conditions with nodes.map=true produce the same WHERE and projection shape.
  @Test
  @Parameters(method = "versions_and_prefixes")
  def testRelationshipFilterComplexConditionsWithMap(neo4j: Neo4j, prefix: String): Unit = {
    val options: java.util.Map[String, String] = new java.util.HashMap[String, String]()
    options.put(Neo4jOptions.URL, "bolt://localhost")
    options.put("relationship", "KNOWS")
    options.put("relationship.nodes.map", "true")
    options.put("relationship.source.labels", "Person")
    options.put("relationship.target.labels", "Person:Customer")
    val neo4jOptions: Neo4jOptions = new Neo4jOptions(options)
    val filters: Array[Filter] = Array[Filter](
      Or(
        Or(EqualTo("source.name", "John Doe"), EqualTo("target.name", "John Doraemon")),
        EqualTo("source.name", "Jane Doe")
      ),
      Or(EqualTo("target.age", 34), EqualTo("target.age", 18)),
      EqualTo("rel.score", 12)
    )
    val query: String =
      new Neo4jQueryService(neo4jOptions, new Neo4jQueryReadStrategy(neo4j, filters)).createQuery()
    val parameterNames = Map(
      "source.name_1" -> "$".concat("source.name".toParameterName("John Doe")),
      "target.name_1" -> "$".concat("target.name".toParameterName("John Doraemon")),
      "source.name_2" -> "$".concat("source.name".toParameterName("Jane Doe")),
      "target.age_1" -> "$".concat("target.age".toParameterName(34)),
      "target.age_2" -> "$".concat("target.age".toParameterName(18)),
      "rel.score" -> "$".concat("rel.score".toParameterName(12))
    )
    assertEquals(
      s"""${prefix}MATCH (source:`Person`)
         | MATCH (target:`Person`:`Customer`)
         | MATCH (source)-[rel:`KNOWS`]->(target)
         | WHERE ((source.name = ${parameterNames("source.name_1")} OR target.name = ${parameterNames("target.name_1")} OR source.name = ${parameterNames("source.name_2")})
         | AND (target.age = ${parameterNames("target.age_1")} OR target.age = ${parameterNames("target.age_2")})
         | AND rel.score = ${parameterNames("rel.score")})
         | RETURN rel, source AS source, target AS target
         |""".stripMargin.replaceAll("\n", ""),
      query
    )
  }

  // --- Write query generation -----------------------------------------------------

  // Overwrite save mode with compound node.keys produces a MERGE keyed on the mapped
  // properties ("SourceCol:targetProp" entries) followed by SET of remaining properties.
  @Test
  @Parameters(method = "versions_and_prefixes")
  def testCompoundKeysForNodes(neo4j: Neo4j, prefix: String): Unit = {
    val options: java.util.Map[String, String] = new java.util.HashMap[String, String]()
    options.put(Neo4jOptions.URL, "bolt://localhost")
    options.put("labels", "Location")
    options.put("node.keys", "LocationName:name,LocationType:type,FeatureID:featureId")
    val neo4jOptions: Neo4jOptions = new Neo4jOptions(options)
    val query: String =
      new Neo4jQueryService(neo4jOptions, new Neo4jQueryWriteStrategy(neo4j, SaveMode.Overwrite)).createQuery()
    assertEquals(
      s"""${prefix}UNWIND $$events AS event
         |MERGE (node:Location {name: event.keys.name, type: event.keys.type, featureId: event.keys.featureId})
         |SET node += event.properties
         |""".stripMargin,
      query
    )
  }

  // Relationship write with compound keys on both endpoints: default endpoint save
  // mode produces MATCH on both nodes, then MERGE of the relationship.
  @Test
  @Parameters(method = "versions_and_prefixes")
  def testCompoundKeysForRelationship(neo4j: Neo4j, prefix: String): Unit = {
    val options: java.util.Map[String, String] = new java.util.HashMap[String, String]()
    options.put(Neo4jOptions.URL, "bolt://localhost")
    options.put("relationship", "BOUGHT")
    options.put("relationship.source.labels", "Person")
    options.put("relationship.source.node.keys", "FirstName:name,LastName:lastName")
    options.put("relationship.target.labels", "Product")
    options.put("relationship.target.node.keys", "ProductPrice:price,ProductId:id")
    val neo4jOptions: Neo4jOptions = new Neo4jOptions(options)
    val query: String =
      new Neo4jQueryService(neo4jOptions, new Neo4jQueryWriteStrategy(neo4j, SaveMode.Overwrite)).createQuery()
    assertEquals(
      s"""${prefix}UNWIND $$events AS event
         |MATCH (source:Person {name: event.source.keys.name, lastName: event.source.keys.lastName})
         |MATCH (target:Product {price: event.target.keys.price, id: event.target.keys.id})
         |MERGE (source)-[rel:BOUGHT]->(target)
         |SET rel += event.rel.properties
         |""".stripMargin,
      query.stripMargin
    )
  }

  // Mixed endpoint save modes: source Overwrite -> MERGE + SET (with a WITH to carry
  // scope), target "match" -> plain MATCH.
  @Test
  @Parameters(method = "versions_and_prefixes")
  def testCompoundKeysForRelationshipMergeMatch(neo4j: Neo4j, prefix: String): Unit = {
    val options: java.util.Map[String, String] = new java.util.HashMap[String, String]()
    options.put(Neo4jOptions.URL, "bolt://localhost")
    options.put("relationship", "BOUGHT")
    options.put("relationship.source.labels", "Person")
    options.put("relationship.source.node.keys", "FirstName:name,LastName:lastName")
    options.put("relationship.source.save.mode", "Overwrite")
    options.put("relationship.target.labels", "Product")
    options.put("relationship.target.node.keys", "ProductPrice:price,ProductId:id")
    options.put("relationship.target.save.mode", "match")
    val neo4jOptions: Neo4jOptions = new Neo4jOptions(options)
    val query: String =
      new Neo4jQueryService(neo4jOptions, new Neo4jQueryWriteStrategy(neo4j, SaveMode.Overwrite)).createQuery()
    assertEquals(
      s"""${prefix}UNWIND $$events AS event
         |MERGE (source:Person {name: event.source.keys.name, lastName: event.source.keys.lastName}) SET source += event.source.properties
         |WITH source, event
         |MATCH (target:Product {price: event.target.keys.price, id: event.target.keys.id})
         |MERGE (source)-[rel:BOUGHT]->(target)
         |SET rel += event.rel.properties
         |""".stripMargin,
      query.stripMargin
    )
  }

  // --- Aggregate push-down --------------------------------------------------------

  // Covers sum/count (distinct and plain) and max/min push-down on a labels read;
  // `query` is reassigned for each of the three aggregate scenarios.
  @Test
  @Parameters(method = "versions_and_prefixes")
  def testShouldDoSumAggregationOnLabels(neo4j: Neo4j, prefix: String): Unit = {
    val options: java.util.Map[String, String] = new java.util.HashMap[String, String]()
    options.put(Neo4jOptions.URL, "bolt://localhost")
    options.put(QueryType.LABELS.toString.toLowerCase, "Person")
    val neo4jOptions: Neo4jOptions = new Neo4jOptions(options)
    val ageField = new DummyNamedReference("age")
    var query: String = new Neo4jQueryService(
      neo4jOptions,
      new Neo4jQueryReadStrategy(
        neo4j,
        Array.empty[Filter],
        PartitionPagination.EMPTY,
        Seq("name", "SUM(DISTINCT age)", "SUM(age)"),
        Array(
          new Sum(ageField, false),
          new Sum(ageField, true)
        )
      )
    ).createQuery()
    assertEquals(
      s"${prefix}MATCH (n:`Person`) RETURN n.name AS name, sum(DISTINCT n.age) AS `SUM(DISTINCT age)`, sum(n.age) AS `SUM(age)`",
      query
    )
    val nameField = new DummyNamedReference("name")
    query = new Neo4jQueryService(
      neo4jOptions,
      new Neo4jQueryReadStrategy(
        neo4j,
        Array.empty[Filter],
        PartitionPagination.EMPTY,
        Seq("name", "COUNT(DISTINCT name)", "COUNT(name)"),
        Array(
          new Count(nameField, false),
          new Count(nameField, true)
        )
      )
    ).createQuery()
    assertEquals(
      s"${prefix}MATCH (n:`Person`) RETURN n.name AS name, count(DISTINCT n.name) AS `COUNT(DISTINCT name)`, count(n.name) AS `COUNT(name)`",
      query
    )
    query = new Neo4jQueryService(
      neo4jOptions,
      new Neo4jQueryReadStrategy(
        neo4j,
        Array.empty[Filter],
        PartitionPagination.EMPTY,
        Seq("name", "MAX(age)", "MIN(age)"),
        Array(
          new Max(ageField),
          new Min(ageField)
        )
      )
    ).createQuery()
    assertEquals(
      s"${prefix}MATCH (n:`Person`) RETURN n.name AS name, max(n.age) AS `MAX(age)`, min(n.age) AS `MIN(age)`",
      query
    )
  }

  // Same aggregate coverage on a relationship read; column names carry backticks,
  // which are doubled (``) inside the backtick-quoted aliases of the expected query.
  @Test
  @Parameters(method = "versions_and_prefixes")
  def testShouldDoSumAggregationOnRelationships(neo4j: Neo4j, prefix: String): Unit = {
    val options: java.util.Map[String, String] = new java.util.HashMap[String, String]()
    options.put(Neo4jOptions.URL, "bolt://localhost")
    options.put("relationship", "BOUGHT")
    options.put("relationship.nodes.map", "false")
    options.put("relationship.source.labels", "Person")
    options.put("relationship.target.labels", "Product")
    val neo4jOptions: Neo4jOptions = new Neo4jOptions(options)
    val targetPriceField = new DummyNamedReference("`target.price`")
    var query: String = new Neo4jQueryService(
      neo4jOptions,
      new Neo4jQueryReadStrategy(
        neo4j,
        Array.empty,
        PartitionPagination.EMPTY,
        List("source.fullName", "SUM(DISTINCT `target.price`)", "SUM(`target.price`)"),
        Array(
          new Sum(targetPriceField, false),
          new Sum(targetPriceField, true)
        )
      )
    ).createQuery()
    assertEquals(
      s"""${prefix}MATCH (source:`Person`)
         |MATCH (target:`Product`)
         |MATCH (source)-[rel:`BOUGHT`]->(target)
         |RETURN source.fullName AS `source.fullName`, sum(DISTINCT target.price) AS `SUM(DISTINCT ``target.price``)`, sum(target.price) AS `SUM(``target.price``)`"""
        .stripMargin
        .replaceAll("\n", " "),
      query
    )
    val targetIdField = new DummyNamedReference("`target.id`")
    query = new Neo4jQueryService(
      neo4jOptions,
      new Neo4jQueryReadStrategy(
        neo4j,
        Array.empty,
        PartitionPagination.EMPTY,
        List("source.fullName", "COUNT(DISTINCT `target.id`)", "COUNT(`target.id`)"),
        Array(
          new Count(targetIdField, false),
          new Count(targetIdField, true)
        )
      )
    ).createQuery()
    assertEquals(
      s"""${prefix}MATCH (source:`Person`) MATCH (target:`Product`)
         |MATCH (source)-[rel:`BOUGHT`]->(target)
         |RETURN source.fullName AS `source.fullName`, count(DISTINCT target.id) AS `COUNT(DISTINCT ``target.id``)`, count(target.id) AS `COUNT(``target.id``)`"""
        .stripMargin
        .replaceAll("\n", " "),
      query
    )
    query = new Neo4jQueryService(
      neo4jOptions,
      new Neo4jQueryReadStrategy(
        neo4j,
        Array.empty,
        PartitionPagination.EMPTY,
        List("source.fullName", "MAX(`target.price`)", "MIN(`target.price`)"),
        Array(
          new Max(targetPriceField),
          new Min(targetPriceField)
        )
      )
    ).createQuery()
    assertEquals(
      s"""${prefix}MATCH (source:`Person`)
         |MATCH (target:`Product`)
         |MATCH (source)-[rel:`BOUGHT`]->(target)
         |RETURN source.fullName AS `source.fullName`, max(target.price) AS `MAX(``target.price``)`, min(target.price) AS `MIN(``target.price``)`"""
        .stripMargin
        .replaceAll("\n", " "),
      query
    )
  }

  // --- TopN (ORDER BY / LIMIT) push-down ------------------------------------------

  // TopN with a sort order on a node read appends ORDER BY ... ASC LIMIT n.
  @Test
  @Parameters(method = "versions_and_prefixes")
  def testTopNForLabels(neo4j: Neo4j, prefix: String): Unit = {
    val options: java.util.Map[String, String] = new java.util.HashMap[String, String]()
    options.put(Neo4jOptions.URL, "bolt://localhost")
    options.put(QueryType.LABELS.toString.toLowerCase, "Person")
    val neo4jOptions: Neo4jOptions = new Neo4jOptions(options)
    val query: String = new Neo4jQueryService(
      neo4jOptions,
      new Neo4jQueryReadStrategy(
        neo4j,
        partitionPagination = PartitionPagination(
          0,
          0,
          TopN(
            42,
            Array(new SortOrder {
              override def expression(): Expression = new DummyNamedReference("name")
              override def direction(): SortDirection = SortDirection.ASCENDING
              override def nullOrdering(): NullOrdering = direction().defaultNullOrdering()
            })
          )
        )
      )
    ).createQuery()
    assertEquals(s"${prefix}MATCH (n:`Person`) RETURN n ORDER BY n.name ASC LIMIT 42", query)
  }

  // TopN combined with column pruning keeps the pruned projection before ORDER BY/LIMIT.
  @Test
  @Parameters(method = "versions_and_prefixes")
  def testTopNForLabelsWithRequiredColumn(neo4j: Neo4j, prefix: String): Unit = {
    val options: java.util.Map[String, String] = new java.util.HashMap[String, String]()
    options.put(Neo4jOptions.URL, "bolt://localhost")
    options.put(QueryType.LABELS.toString.toLowerCase, "Person")
    val neo4jOptions: Neo4jOptions = new Neo4jOptions(options)
    val query: String = new Neo4jQueryService(
      neo4jOptions,
      new Neo4jQueryReadStrategy(
        neo4j,
        requiredColumns = Array("name"),
        partitionPagination = PartitionPagination(
          0,
          0,
          TopN(
            42,
            Array(new SortOrder {
              override def expression(): Expression = new DummyNamedReference("name")
              override def direction(): SortDirection = SortDirection.ASCENDING
              override def nullOrdering(): NullOrdering = direction().defaultNullOrdering()
            })
          )
        )
      )
    ).createQuery()
    assertEquals(s"${prefix}MATCH (n:`Person`) RETURN n.name AS name ORDER BY n.name ASC LIMIT 42", query)
  }

  // TopN on a relationship read, sorting on a rel property, descending.
  @Test
  @Parameters(method = "versions_and_prefixes")
  def testTopNForRelationships(neo4j: Neo4j, prefix: String): Unit = {
    val options: java.util.Map[String, String] = new java.util.HashMap[String, String]()
    options.put(Neo4jOptions.URL, "bolt://localhost")
    options.put("relationship", "KNOWS")
    options.put("relationship.nodes.map", "false")
    options.put("relationship.source.labels", "Person")
    options.put("relationship.target.labels", "Person")
    val neo4jOptions: Neo4jOptions = new Neo4jOptions(options)
    val query: String = new Neo4jQueryService(
      neo4jOptions,
      new Neo4jQueryReadStrategy(
        neo4j,
        Array.empty[Filter],
        PartitionPagination(
          0,
          0,
          TopN(
            24,
            Array(new SortOrder {
              override def expression(): Expression = new DummyNamedReference("rel.since")
              override def direction(): SortDirection = SortDirection.DESCENDING
              override def nullOrdering(): NullOrdering = direction().defaultNullOrdering()
            })
          )
        )
      )
    ).createQuery()
    assertEquals(
      s"${prefix}MATCH (source:`Person`) " +
        "MATCH (target:`Person`) " +
        "MATCH (source)-[rel:`KNOWS`]->(target) RETURN rel, source AS source, target AS target " +
        "ORDER BY rel.since DESC LIMIT 24",
      query
    )
  }

  // TopN on a relationship read combined with a single pruned column.
  @Test
  @Parameters(method = "versions_and_prefixes")
  def testTopNForRelationshipWithOneRequiredColumn(neo4j: Neo4j, prefix: String): Unit = {
    val options: java.util.Map[String, String] = new java.util.HashMap[String, String]()
    options.put(Neo4jOptions.URL, "bolt://localhost")
    options.put("relationship", "KNOWS")
    options.put("relationship.nodes.map", "false")
    options.put("relationship.source.labels", "Person")
    options.put("relationship.target.labels", "Person")
    val neo4jOptions: Neo4jOptions = new Neo4jOptions(options)
    val query: String = new Neo4jQueryService(
      neo4jOptions,
      new Neo4jQueryReadStrategy(
        neo4j,
        Array.empty[Filter],
        PartitionPagination(
          0,
          0,
          TopN(
            24,
            Array(new SortOrder {
              override def expression(): Expression = new DummyNamedReference("rel.since")
              override def direction(): SortDirection = SortDirection.DESCENDING
              override def nullOrdering(): NullOrdering = direction().defaultNullOrdering()
            })
          )
        ),
        Array("source.name")
      )
    ).createQuery()
    assertEquals(
      s"""${prefix}MATCH (source:`Person`)
         |MATCH (target:`Person`)
         |MATCH (source)-[rel:`KNOWS`]->(target) RETURN source.name AS `source.name`
         |ORDER BY rel.since DESC LIMIT 24"""
        .stripMargin
        .replaceAll("\n", " "),
      query
    )
  }

  // For a user-supplied Cypher query the TopN sort is NOT pushed down (no CYPHER prefix
  // either, hence the unused `ignored` parameter): pagination falls back to SKIP/LIMIT.
  @Test
  @Parameters(method = "versions_and_prefixes")
  def testTopNForCustomQueryIgnoresAggregation(neo4j: Neo4j, ignored: String): Unit = {
    val options: java.util.Map[String, String] = new java.util.HashMap[String, String]()
    options.put(Neo4jOptions.URL, "bolt://localhost")
    options.put(QueryType.QUERY.toString.toLowerCase, "MATCH (p:Person) RETURN p")
    val neo4jOptions: Neo4jOptions = new Neo4jOptions(options)
    val query: String = new Neo4jQueryService(
      neo4jOptions,
      new Neo4jQueryReadStrategy(
        neo4j,
        Array.empty[Filter],
        PartitionPagination(
          0,
          0,
          TopN(
            24,
            Array(new SortOrder {
              override def expression(): Expression = new DummyNamedReference("name")
              override def direction(): SortDirection = SortDirection.DESCENDING
              override def nullOrdering(): NullOrdering = direction().defaultNullOrdering()
            })
          )
        )
      )
    ).createQuery()
    assertEquals(
      "WITH $scriptResult AS scriptResult MATCH (p:Person) RETURN p SKIP 0 LIMIT 24",
      query
    )
  }

  // --- Parameter sources and helpers ----------------------------------------------

  // JUnitParams data source: (server, expected prefix) pairs. 4.4/5.0 servers yield
  // no prefix; 5.21, 5.26 and calendar versions are expected to emit "CYPHER 5 ".
  def versions_and_prefixes(): Array[Array[Any]] = {
    Array(
      Array(neo4j(version(4, 4), COMMUNITY), ""),
      Array(neo4j(version(4, 4), ENTERPRISE), ""),
      Array(neo4j(version(5, 0), COMMUNITY), ""),
      Array(neo4j(version(5, 0), ENTERPRISE), ""),
      Array(neo4j(version(5, 21), COMMUNITY), "CYPHER 5 "),
      Array(neo4j(version(5, 21), ENTERPRISE), "CYPHER 5 "),
      Array(neo4j(version(5, 26), COMMUNITY), "CYPHER 5 "),
      Array(neo4j(version(5, 26), ENTERPRISE), "CYPHER 5 "),
      Array(neo4j(version(2025, 1), COMMUNITY), "CYPHER 5 "),
      Array(neo4j(version(2025, 1), ENTERPRISE), "CYPHER 5 ")
    )
  }

  // Builds a self-managed Neo4j descriptor for the given version and edition.
  def neo4j(version: Neo4jVersion, edition: Neo4jEdition): Neo4j = {
    new Neo4j(version, edition, SELF_MANAGED)
  }

  // Builds a Neo4jVersion with the patch component fixed to 0.
  def version(major: Int, minor: Int): Neo4jVersion = {
    new
Neo4jVersion(major, minor, 0) } } ================================================ FILE: common/src/test/scala/org/neo4j/spark/service/SchemaServiceTSE.scala ================================================ /* * Copyright (c) "Neo4j" * Neo4j Sweden AB [https://neo4j.com] * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ package org.neo4j.spark.service import org.apache.spark.sql.types.DataTypes import org.apache.spark.sql.types.StructField import org.apache.spark.sql.types.StructType import org.junit.Assert._ import org.junit.Before import org.junit.FixMethodOrder import org.junit.Test import org.junit.runners.MethodSorters import org.neo4j.Closeables.use import org.neo4j.caniuse.Neo4j import org.neo4j.caniuse.Neo4jDeploymentType import org.neo4j.caniuse.Neo4jEdition import org.neo4j.caniuse.Neo4jVersion import org.neo4j.driver.Transaction import org.neo4j.driver.TransactionWork import org.neo4j.driver.summary.ResultSummary import org.neo4j.spark.SparkConnectorScalaBaseTSE import org.neo4j.spark.SparkConnectorScalaSuiteIT import org.neo4j.spark.SparkConnectorScalaSuiteIT.neo4j import org.neo4j.spark.converter.CypherToSparkTypeConverter import org.neo4j.spark.util.DriverCache import org.neo4j.spark.util.Neo4jOptions import org.neo4j.spark.util.Neo4jUtil import org.neo4j.spark.util.QueryType import java.util import java.util.UUID @FixMethodOrder(MethodSorters.JVM) class SchemaServiceTSE extends SparkConnectorScalaBaseTSE { @Before def beforeEach(): Unit = { 
use(SparkConnectorScalaSuiteIT.session("system")) { session =>
      // Recreate the default database so every test starts from an empty graph.
      session.run("CREATE OR REPLACE DATABASE neo4j WAIT 30 seconds")
        .consume()
    }
  }

  // Each test seeds the graph with one CREATE statement, then checks the schema
  // inferred for the `Person` label against the expected property fields.

  @Test
  def testGetSchemaFromNodeBoolean(): Unit =
    assertPersonSchema("CREATE (p:Person {is_hero: true})", StructField("is_hero", DataTypes.BooleanType))

  @Test
  def testGetSchemaFromNodeString(): Unit =
    assertPersonSchema("CREATE (p:Person {name: 'John'})", StructField("name", DataTypes.StringType))

  @Test
  def testGetSchemaFromNodeLong(): Unit =
    assertPersonSchema("CREATE (p:Person {age: 93})", StructField("age", DataTypes.LongType))

  @Test
  def testGetSchemaFromNodeDouble(): Unit =
    assertPersonSchema("CREATE (p:Person {ratio: 43.120})", StructField("ratio", DataTypes.DoubleType))

  @Test
  def testGetSchemaFromNodePoint2D(): Unit =
    assertPersonSchema(
      "CREATE (p:Person {location: point({x: 12.32, y: 49.32})})",
      StructField("location", CypherToSparkTypeConverter.pointType)
    )

  @Test
  def testGetSchemaFromDate(): Unit =
    assertPersonSchema(
      "CREATE (p:Person {born_on: date('1998-01-05')})",
      StructField("born_on", DataTypes.DateType)
    )

  @Test
  def testGetSchemaFromDateTime(): Unit =
    assertPersonSchema(
      "CREATE (p:Person {arrived_at: datetime('1998-01-05')})",
      StructField("arrived_at", DataTypes.TimestampType)
    )

  @Test
  def testGetSchemaFromTime(): Unit =
    assertPersonSchema(
      "CREATE (p:Person {arrived_at: time('125035.556+0100')})",
      StructField("arrived_at", CypherToSparkTypeConverter.timeType)
    )

  @Test
  def testGetSchemaFromStringArray(): Unit =
    assertPersonSchema(
      "CREATE (p:Person {names: ['John', 'Doe']})",
      StructField("names", DataTypes.createArrayType(DataTypes.StringType))
    )

  @Test
  def testGetSchemaFromDateArray(): Unit =
    assertPersonSchema(
      "CREATE (p:Person {names: [date('2019-11-19'), date('2019-11-20')]})",
      StructField("names", DataTypes.createArrayType(DataTypes.DateType))
    )

  @Test
  def testGetSchemaFromTimestampArray(): Unit =
    assertPersonSchema(
      "CREATE (p:Person {dates: [datetime('2019-11-19'), datetime('2019-11-20')]})",
      StructField("dates", DataTypes.createArrayType(DataTypes.TimestampType))
    )

  @Test
  def testGetSchemaFromTimeArray(): Unit =
    assertPersonSchema(
      "CREATE (p:Person {dates: [time('125035.556+0100'), time('125125.556+0100')]})",
      StructField("dates", DataTypes.createArrayType(CypherToSparkTypeConverter.timeType))
    )

  @Test
  def testGetSchemaFromIntegerArray(): Unit =
    assertPersonSchema(
      "CREATE (p:Person {ages: [42, 101]})",
      StructField("ages", DataTypes.createArrayType(DataTypes.LongType))
    )

  @Test
  def testGetSchemaFromMultipleNodes(): Unit =
    // Properties missing on some nodes (or explicitly null) must still be merged
    // into a single schema across all Person nodes.
    assertPersonSchema(
      """
      CREATE (p1:Person {age: 31, name: 'Jane Doe'}),
       (p2:Person {name: 'John Doe', age: 33, location: null}),
       (p3:Person {age: 25, location: point({latitude: 12.12, longitude: 31.13})})
      """,
      StructField("age", DataTypes.LongType),
      StructField("location", CypherToSparkTypeConverter.pointType),
      StructField("name", DataTypes.StringType)
    )

  /**
   * Seeds the database with `setupQuery`, reads the schema inferred for the
   * `Person` label and asserts it equals `expectedFields` plus the connector's
   * internal labels/id columns.
   */
  private def assertPersonSchema(setupQuery: String, expectedFields: StructField*): Unit = {
    initTest(setupQuery)
    val options: java.util.Map[String, String] = new util.HashMap[String, String]()
    options.put(QueryType.LABELS.toString.toLowerCase, "Person")
    assertEquals(getExpectedStructType(expectedFields), getSchema(options))
  }

  // Appends the connector-internal labels/id fields and mirrors the ordering
  // produced by SchemaService.struct().
  private def getExpectedStructType(structFields: Seq[StructField]): StructType = {
    val internalFields = Seq(
      StructField(Neo4jUtil.INTERNAL_LABELS_FIELD, DataTypes.createArrayType(DataTypes.StringType), nullable = true),
      StructField(Neo4jUtil.INTERNAL_ID_FIELD, DataTypes.LongType, nullable = false)
    )
    StructType((structFields ++ internalFields).reverse)
  }

  // Runs the given seed query in a write transaction against the suite server.
  private def initTest(query: String): Unit =
    SparkConnectorScalaSuiteIT.session()
      .writeTransaction(
        new TransactionWork[ResultSummary] {
          override def execute(tx: Transaction): ResultSummary = tx.run(query).consume()
        }
      )

  // Builds a SchemaService from the given options and returns the inferred
  // struct, closing the service and the driver cache afterwards.
  private def getSchema(options: java.util.Map[String, String]): StructType = {
    options.put(Neo4jOptions.URL, SparkConnectorScalaSuiteIT.server.getBoltUrl)
    val neo4jOptions: Neo4jOptions = new Neo4jOptions(options)
    val driverCache = new DriverCache(neo4jOptions.connection)
    val schemaService: SchemaService = new SchemaService(neo4j, neo4jOptions, driverCache)
    val schema: StructType = schemaService.struct()
    schemaService.close()
    driverCache.close()
    schema
  }
}
================================================
FILE: common/src/test/scala/org/neo4j/spark/service/SchemaServiceTest.scala
================================================
/*
 * Copyright (c) "Neo4j"
 * Neo4j Sweden AB [https://neo4j.com]
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
*/
package org.neo4j.spark.service

import org.junit.Assert.assertEquals
import org.junit.Test
import org.mockito.Mockito.RETURNS_DEEP_STUBS
import org.mockito.Mockito.mock
import org.neo4j.caniuse.Neo4j
import org.neo4j.caniuse.Neo4jDeploymentType
import org.neo4j.caniuse.Neo4jEdition
import org.neo4j.caniuse.Neo4jVersion
import org.neo4j.spark.config.TopN
import org.neo4j.spark.util.DriverCache
import org.neo4j.spark.util.Neo4jOptions

import scala.annotation.nowarn
import scala.collection.JavaConverters

class SchemaServiceTest {

  @Test
  def does_not_overflow_when_partition_size_is_over_max_value_of_32bit_integers(): Unit = {
    // note: _ cannot be used to separate digit groups, as this requires Scala 2.13+
    val totalRows = 2L * 2147483648L // 2 * (Integer.MAX_VALUE + 1)
    val opts = options(
      "url" -> "bolt://example.com",
      "partitions" -> "2",
      "query.count" -> totalRows.toString,
      "query" -> "MERGE (:Node)"
    )

    val service = new SchemaService(neo4j(), opts, mock(classOf[DriverCache], RETURNS_DEEP_STUBS))
    val pages = service.skipLimitFromPartition(Some(TopN(1024)))

    // Two partitions of Integer.MAX_VALUE + 1 rows each: skip/limit must be
    // computed in Long arithmetic, not Int.
    assertEquals(List(0, 1), pages.map(_.partitionNumber).toList)
    assertEquals(List(0, 2147483648L), pages.map(_.skip).toList)
    assertEquals(List(2147483648L, 2147483648L), pages.map(_.topN.limit).toList)
    assertEquals(List(0, 0), pages.map(_.topN.orders.size).toList)
  }

  // Builds Neo4jOptions from the given key/value pairs.
  private def options(kv: (String, String)*): Neo4jOptions =
    new Neo4jOptions(JavaConverters.mapAsJavaMap(kv.toMap))

  // A fixed recent server version; the test is not version-specific.
  private def neo4j(): Neo4j =
    new Neo4j(new Neo4jVersion(2025, 1, 0), Neo4jEdition.COMMUNITY, Neo4jDeploymentType.SELF_MANAGED)
}
================================================
FILE: common/src/test/scala/org/neo4j/spark/service/SchemaServiceWithApocTSE.scala
================================================
/*
 * Copyright (c) "Neo4j"
 * Neo4j Sweden AB [https://neo4j.com]
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package org.neo4j.spark.service

import org.apache.spark.sql.types.DataTypes
import org.apache.spark.sql.types.StructField
import org.apache.spark.sql.types.StructType
import org.junit.Assert._
import org.junit.Before
import org.junit.FixMethodOrder
import org.junit.Test
import org.junit.runners.MethodSorters
import org.neo4j.Closeables.use
import org.neo4j.caniuse.Neo4j
import org.neo4j.caniuse.Neo4jDeploymentType
import org.neo4j.caniuse.Neo4jEdition
import org.neo4j.caniuse.Neo4jVersion
import org.neo4j.driver.Transaction
import org.neo4j.driver.TransactionWork
import org.neo4j.driver.summary.ResultSummary
import org.neo4j.spark.SparkConnectorScalaBaseWithApocTSE
import org.neo4j.spark.SparkConnectorScalaSuiteWithApocIT
import org.neo4j.spark.converter.CypherToSparkTypeConverter
import org.neo4j.spark.util.DriverCache
import org.neo4j.spark.util.Neo4jOptions
import org.neo4j.spark.util.Neo4jUtil
import org.neo4j.spark.util.QueryType

import java.util
import java.util.UUID

@FixMethodOrder(MethodSorters.JVM)
class SchemaServiceWithApocTSE extends SparkConnectorScalaBaseWithApocTSE {

  @Before
  def beforeEach(): Unit =
    use(SparkConnectorScalaSuiteWithApocIT.session("system")) { session =>
      // Recreate the default database so every test starts from an empty graph.
      session.run("CREATE OR REPLACE DATABASE neo4j WAIT 30 seconds")
        .consume()
    }

  // Each test seeds the graph with one CREATE statement, then checks the schema
  // inferred (via APOC) for the `Person` label against the expected fields.

  @Test
  def testGetSchemaFromNodeBoolean(): Unit =
    assertPersonSchema("CREATE (p:Person {is_hero: true})", StructField("is_hero", DataTypes.BooleanType))

  @Test
  def testGetSchemaFromNodeString(): Unit =
    assertPersonSchema("CREATE (p:Person {name: 'John'})", StructField("name", DataTypes.StringType))

  @Test
  def testGetSchemaFromNodeLong(): Unit =
    assertPersonSchema("CREATE (p:Person {age: 93})", StructField("age", DataTypes.LongType))

  @Test
  def testGetSchemaFromNodeDouble(): Unit =
    assertPersonSchema("CREATE (p:Person {ratio: 43.120})", StructField("ratio", DataTypes.DoubleType))

  @Test
  def testGetSchemaFromNodePoint2D(): Unit =
    assertPersonSchema(
      "CREATE (p:Person {location: point({x: 12.32, y: 49.32})})",
      StructField("location", CypherToSparkTypeConverter.pointType)
    )

  @Test
  def testGetSchemaFromDate(): Unit =
    assertPersonSchema(
      "CREATE (p:Person {born_on: date('1998-01-05')})",
      StructField("born_on", DataTypes.DateType)
    )

  @Test
  def testGetSchemaFromDateTime(): Unit =
    assertPersonSchema(
      "CREATE (p:Person {arrived_at: datetime('1998-01-05')})",
      StructField("arrived_at", DataTypes.TimestampType)
    )

  @Test
  def testGetSchemaFromTime(): Unit =
    assertPersonSchema(
      "CREATE (p:Person {arrived_at: time('125035.556+0100')})",
      StructField("arrived_at", CypherToSparkTypeConverter.timeType)
    )

  @Test
  def testGetSchemaFromStringArray(): Unit =
    assertPersonSchema(
      "CREATE (p:Person {names: ['John', 'Doe']})",
      StructField("names", DataTypes.createArrayType(DataTypes.StringType))
    )

  @Test
  def testGetSchemaFromDateArray(): Unit =
    assertPersonSchema(
      "CREATE (p:Person {names: [date('2019-11-19'), date('2019-11-20')]})",
      StructField("names", DataTypes.createArrayType(DataTypes.DateType))
    )

  @Test
  def testGetSchemaFromTimestampArray(): Unit =
    assertPersonSchema(
      "CREATE (p:Person {dates: [datetime('2019-11-19'), datetime('2019-11-20')]})",
      StructField("dates", DataTypes.createArrayType(DataTypes.TimestampType))
    )

  @Test
  def testGetSchemaFromTimeArray(): Unit =
    assertPersonSchema(
      "CREATE (p:Person {dates: [time('125035.556+0100'), time('125125.556+0100')]})",
      StructField("dates", DataTypes.createArrayType(CypherToSparkTypeConverter.timeType))
    )

  @Test
  def testGetSchemaFromIntegerArray(): Unit =
    assertPersonSchema(
      "CREATE (p:Person {ages: [42, 101]})",
      StructField("ages", DataTypes.createArrayType(DataTypes.LongType))
    )

  @Test
  def testGetSchemaFromMultipleNodes(): Unit =
    // Properties missing on some nodes (or explicitly null) must still be merged
    // into a single schema across all Person nodes.
    assertPersonSchema(
      """
      CREATE (p1:Person {age: 31, name: 'Jane Doe'}),
       (p2:Person {name: 'John Doe', age: 33, location: null}),
       (p3:Person {age: 25, location: point({latitude: 12.12, longitude: 31.13})})
      """,
      StructField("age", DataTypes.LongType),
      StructField("location", CypherToSparkTypeConverter.pointType),
      StructField("name", DataTypes.StringType)
    )

  /**
   * Seeds the database with `setupQuery`, reads the schema inferred for the
   * `Person` label and asserts it equals `expectedFields` plus the connector's
   * internal labels/id columns.
   */
  private def assertPersonSchema(setupQuery: String, expectedFields: StructField*): Unit = {
    initTest(setupQuery)
    val options: java.util.Map[String, String] = new util.HashMap[String, String]()
    options.put(QueryType.LABELS.toString.toLowerCase, "Person")
    assertEquals(getExpectedStructType(expectedFields), getSchema(options))
  }

  // Appends the connector-internal labels/id fields and mirrors the ordering
  // produced by SchemaService.struct().
  private def getExpectedStructType(structFields: Seq[StructField]): StructType = {
    val internalFields = Seq(
      StructField(Neo4jUtil.INTERNAL_LABELS_FIELD, DataTypes.createArrayType(DataTypes.StringType), nullable = true),
      StructField(Neo4jUtil.INTERNAL_ID_FIELD, DataTypes.LongType, nullable = false)
    )
    StructType((structFields ++ internalFields).reverse)
  }

  // Runs the given seed query in a write transaction against the APOC suite server.
  private def initTest(query: String): Unit =
    SparkConnectorScalaSuiteWithApocIT.session()
      .writeTransaction(
        new TransactionWork[ResultSummary] {
          override def execute(tx: Transaction): ResultSummary = tx.run(query).consume()
        }
      )

  // Builds a SchemaService from the given options and returns the inferred
  // struct, closing the service and the driver cache afterwards.
  private def getSchema(options: java.util.Map[String, String]): StructType = {
    options.put(Neo4jOptions.URL, SparkConnectorScalaSuiteWithApocIT.server.getBoltUrl)
    val neo4jOptions: Neo4jOptions = new Neo4jOptions(options)
    val driverCache = new DriverCache(neo4jOptions.connection)
    val neo4j = new Neo4j(new Neo4jVersion(4, 4, 0), Neo4jEdition.ENTERPRISE, Neo4jDeploymentType.SELF_MANAGED)
    val schemaService: SchemaService = new SchemaService(neo4j, neo4jOptions, driverCache)
    val schema: StructType = schemaService.struct()
    schemaService.close()
    driverCache.close()
    schema
  }
}
================================================
FILE: common/src/test/scala/org/neo4j/spark/util/DummyNamedReference.scala
================================================
/*
 * Copyright (c) "Neo4j"
 * Neo4j Sweden AB [https://neo4j.com]
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
*/
package org.neo4j.spark.util

import org.apache.spark.sql.connector.expressions.NamedReference

/**
 * Minimal test double for Spark's `NamedReference`: wraps the given field-path
 * segments and renders them comma-separated via `describe()`.
 */
class DummyNamedReference(private val fields: String*) extends NamedReference {

  // The field path segments, in the order they were supplied.
  override def fieldNames(): Array[String] = fields.toArray

  override def describe(): String = fields.mkString(", ")

  // Delegates to describe() so test failure messages stay readable.
  override def toString: String = describe()
}
================================================
FILE: common/src/test/scala/org/neo4j/spark/util/Neo4jImplicitsTest.scala
================================================
/*
 * Copyright (c) "Neo4j"
 * Neo4j Sweden AB [https://neo4j.com]
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
*/
package org.neo4j.spark.util

import org.apache.spark.sql.connector.expressions.NamedReference
import org.apache.spark.sql.connector.expressions.aggregate.Aggregation
import org.apache.spark.sql.connector.expressions.aggregate.Sum
import org.apache.spark.sql.sources.And
import org.apache.spark.sql.sources.EqualTo
import org.apache.spark.sql.types.DataTypes
import org.apache.spark.sql.types.StructField
import org.apache.spark.sql.types.StructType
import org.junit.Assert
import org.junit.Assert._
import org.junit.Test
import org.neo4j.spark.util.Neo4jImplicits._

import scala.collection.JavaConverters.mapAsJavaMapConverter
import scala.collection.JavaConverters.mapAsScalaMapConverter
import scala.collection.JavaConverters.seqAsJavaListConverter
import scala.collection.immutable.ListMap

class Neo4jImplicitsTest {

  @Test
  def `should quote the string`: Unit = {
    val input = "Test with space"
    assertEquals(s"`$input`", input.quote)
  }

  @Test
  def `should quote text that starts with $`: Unit = {
    val input = "$tring"
    assertEquals(s"`$input`", input.quote)
  }

  @Test
  def `should not re-quote the string`: Unit = {
    // Already back-quoted input is returned unchanged.
    val input = "`Test with space`"
    assertEquals(input, input.quote)
  }

  @Test
  def `should not quote the string`: Unit = {
    // Plain identifiers need no quoting.
    val input = "Test"
    assertEquals(input, input.quote)
  }

  @Test
  def `should return attribute if filter has it`: Unit =
    assertTrue(EqualTo("name", "John").getAttribute.isDefined)

  @Test
  def `should return an empty option if the filter doesn't have an attribute`: Unit = {
    // A composite filter has no single attribute.
    val composite = And(EqualTo("name", "John"), EqualTo("age", 32))
    assertFalse(composite.getAttribute.isDefined)
  }

  @Test
  def `should return the attribute without the entity identifier`: Unit = {
    val filter = EqualTo("person.address.coords", 32)
    assertEquals("address.coords", filter.getAttributeWithoutEntityName.get)
  }

  @Test
  def `struct should return true if contains fields`: Unit = {
    val struct = StructType(Seq(
      StructField("is_hero", DataTypes.BooleanType),
      StructField("name", DataTypes.StringType),
      StructField("fi``(╯°□°)╯︵ ┻━┻eld", DataTypes.StringType)
    ))
    assertEquals(0, struct.getMissingFields(Set("is_hero", "name", "fi``(╯°□°)╯︵ ┻━┻eld")).size)
  }

  @Test
  def `struct should return false if not contains fields`: Unit = {
    val struct = StructType(Seq(
      StructField("is_hero", DataTypes.BooleanType),
      StructField("name", DataTypes.StringType)
    ))
    assertEquals(Set[String]("hero_name"), struct.getMissingFields(Set("is_hero", "hero_name")))
  }

  @Test
  def `getMissingFields should handle maps`: Unit = {
    val struct = StructType(Seq(
      StructField("im", DataTypes.StringType),
      StructField("im.a", DataTypes.createMapType(DataTypes.StringType, DataTypes.StringType)),
      StructField("im.also.a", DataTypes.createMapType(DataTypes.StringType, DataTypes.StringType)),
      StructField("im.not.a.map", DataTypes.StringType),
      StructField("fi``(╯°□°)╯︵ ┻━┻eld", DataTypes.StringType)
    ))
    // Back-quoted prefixes address map columns; only the unquoted, unknown
    // dotted name should be reported as missing.
    val missing = struct.getMissingFields(Set(
      "im.aMap",
      "`im.also.a`.field",
      "`im.a`.map",
      "`im.not.a.map`",
      "fi``(╯°□°)╯︵ ┻━┻eld"
    ))
    assertEquals(Set("im.aMap"), missing)
  }

  @Test
  def `groupByCols aggregation should work`: Unit = {
    def ref(name: String): NamedReference = new NamedReference {
      override def fieldNames(): Array[String] = Array(name)
      override def describe(): String = name
    }

    val agg = new Aggregation(Array(new Sum(ref("foo"), false)), Array(ref("bar")))

    assertEquals(1, agg.groupByCols().length)
    assertEquals("bar", agg.groupByCols()(0).describe())
  }

  @Test
  def `should flatten the map`(): Unit = {
    val nested = Map(
      "foo" -> "bar",
      "key" -> Map(
        "innerKey" -> Map("innerKey2" -> "value")
      )
    )
    val expected = Map(
      "foo" -> "bar",
      "key.innerKey.innerKey2" -> "value"
    )
    Assert.assertEquals(expected, nested.flattenMap())
  }

  @Test
  def `should not handle collision`(): Unit = {
    val colliding = ListMap(
      "my" -> Map(
        "inner" -> Map("key" -> 42424242),
        "inner.key" -> 424242
      ),
      "my.inner" -> Map("key" -> 4242).asJava,
      "my.inner.key" -> 42
    )
    // Without grouping, later entries for the same flattened key overwrite earlier ones.
    Assert.assertEquals(Map("my.inner.key" -> 42), colliding.flattenMap())
  }

  @Test
  def `should handle collision by aggregating values`(): Unit = {
    val colliding = ListMap(
      "my" -> Map(
        "inner" -> Map("key" -> 42424242),
        "inner.key" -> 424242
      ),
      "my.inner" -> Map("key" -> 4242).asJava,
      "my.inner.key" -> 42
    )
    // With groupDuplicateKeys all colliding values are collected in insertion order.
    val expected = Map(
      "my.inner.key" -> Seq(42424242, 424242, 4242, 42).asJava
    )
    Assert.assertEquals(expected, colliding.flattenMap(groupDuplicateKeys = true))
  }

  @Test
  def `should show duplicate keys`(): Unit = {
    val colliding = Map(
      "my" -> Map(
        "inner" -> Map("key" -> 42424242),
        "inner.key" -> 424242
      ),
      "my.inner" -> Map("key" -> 4242).asJava,
      "my.inner.key" -> 42
    )
    val expected = Seq("my.inner.key", "my.inner.key", "my.inner.key", "my.inner.key")
    Assert.assertEquals(expected, colliding.flattenKeys())
  }

  @Test
  def `should deserialized dotted/stringified map into a nested Java map`(): Unit = {
    // Stringified values are parsed into their natural types (numbers, lists, booleans).
    val gdsInput = Map(
      "graphName" -> "foo",
      "configuration.number" -> "1",
      "configuration.string" -> "foo",
      "configuration.list" -> "['a', 1]",
      "configuration.map.key" -> "value",
      "relationshipProjection.LINK.properties.foobar.defaultValue" -> "42.0"
    )
    val gdsExpected: java.util.Map[String, Object] = Map(
      "graphName" -> "foo",
      "configuration" -> Map(
        "number" -> 1,
        "string" -> "foo",
        "list" -> Seq("a", 1).toList.asJava,
        "map" -> Map(
          "key" -> "value"
        ).asJava
      ).asJava,
      "relationshipProjection" -> Map(
        "LINK" -> Map(
          "properties" -> Map(
            "foobar" -> Map("defaultValue" -> 42.0).asJava
          ).asJava
        ).asJava
      ).asJava
    ).asJava
    Assert.assertEquals(gdsExpected, gdsInput.toNestedJavaMap)

    val ucInput = Map(
      "graphName" -> "myGraph",
      "nodeProjection" -> "Website",
      "relationshipProjection.LINK.indexInverse" -> "true"
    )
    val ucExpected: java.util.Map[String, Object] = Map(
      "graphName" -> "myGraph",
      "nodeProjection" -> "Website",
      "relationshipProjection" -> Map(
        "LINK" -> Map(
          "indexInverse" -> true
        ).asJava
      ).asJava
    ).asJava
    Assert.assertEquals(ucExpected, ucInput.toNestedJavaMap)
  }
}
================================================
FILE: common/src/test/scala/org/neo4j/spark/util/Neo4jOptionsIT.scala
================================================
/*
 * Copyright (c) "Neo4j"
 * Neo4j Sweden AB [https://neo4j.com]
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
*/
package org.neo4j.spark.util

import org.junit.Assert.assertEquals
import org.junit.Assert.assertNotNull
import org.junit.Ignore
import org.junit.Test
import org.neo4j.Closeables.use
import org.neo4j.spark.SparkConnectorScalaSuiteIT
import org.neo4j.spark.SparkConnectorScalaSuiteIT.server

class Neo4jOptionsIT extends SparkConnectorScalaSuiteIT {

  @Test
  def shouldConstructDriver(): Unit = {
    val options: java.util.Map[String, String] = new java.util.HashMap[String, String]()
    options.put(Neo4jOptions.URL, server.getBoltUrl)
    options.put(Neo4jOptions.AUTH_TYPE, "none")

    assertDriverConnects(new Neo4jOptions(options))
  }

  @Test
  def shouldConstructDriverWithResolver(): Unit = {
    // The first two addresses are intentionally unreachable: the resolver must
    // fall through to the live test server.
    val options: java.util.Map[String, String] = new java.util.HashMap[String, String]()
    options.put(
      Neo4jOptions.URL,
      s"neo4j://localhost.localdomain:8888, bolt://localhost.localdomain:9999, ${server.getBoltUrl}"
    )
    options.put(Neo4jOptions.AUTH_TYPE, "none")

    assertDriverConnects(new Neo4jOptions(options))
  }

  // Creates a driver from the given options and verifies a trivial round trip,
  // closing both the session and the driver afterwards.
  private def assertDriverConnects(neo4jOptions: Neo4jOptions): Unit =
    use(neo4jOptions.connection.createDriver()) { driver =>
      assertNotNull(driver)
      use(driver.session()) { session =>
        assertEquals(1, session.run("RETURN 1").single().get(0).asInt())
      }
    }
}
================================================
FILE: common/src/test/scala/org/neo4j/spark/util/Neo4jOptionsTest.scala
================================================
/*
 * Copyright (c) "Neo4j"
 * Neo4j Sweden AB [https://neo4j.com]
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
* You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package org.neo4j.spark.util

import org.junit.Assert._
import org.junit.Test
import org.neo4j.driver.AccessMode
import org.neo4j.driver.net.ServerAddress

import java.net.URI
import java.time.Duration
import scala.annotation.meta.getter
import scala.collection.JavaConverters._

class Neo4jOptionsTest {

  import org.junit.Rule
  import org.junit.rules.ExpectedException

  @(Rule @getter)
  val _expectedException: ExpectedException = ExpectedException.none

  // Builds a mutable options map pre-populated with the given pairs.
  private def rawOptions(pairs: (String, String)*): java.util.Map[String, String] = {
    val map = new java.util.HashMap[String, String]()
    pairs.foreach { case (key, value) => map.put(key, value) }
    map
  }

  @Test
  def testUrlIsRequired(): Unit = {
    val options = rawOptions(QueryType.QUERY.toString.toLowerCase -> "Person")

    _expectedException.expect(classOf[IllegalArgumentException])
    _expectedException.expectMessage("Parameter 'url' is required")

    new Neo4jOptions(options)
  }

  @Test
  def testRelationshipTableName(): Unit = {
    val neo4jOptions = new Neo4jOptions(rawOptions(
      Neo4jOptions.URL -> "bolt://localhost",
      QueryType.RELATIONSHIP.toString.toLowerCase -> "KNOWS",
      Neo4jOptions.RELATIONSHIP_SOURCE_LABELS -> "Person",
      Neo4jOptions.RELATIONSHIP_TARGET_LABELS -> "Answer"
    ))
    assertEquals("table_Person_KNOWS_Answer", neo4jOptions.getTableName)
  }

  @Test
  def testLabelsTableName(): Unit = {
    val neo4jOptions = new Neo4jOptions(rawOptions(
      Neo4jOptions.URL -> "bolt://localhost",
      "labels" -> "Person:Admin"
    ))
    assertEquals("table_Person-Admin", neo4jOptions.getTableName)
  }

  @Test
  def testRelationshipNodeModesAreCaseInsensitive(): Unit = {
    val neo4jOptions = new Neo4jOptions(rawOptions(
      Neo4jOptions.URL -> "bolt://localhost",
      QueryType.RELATIONSHIP.toString.toLowerCase -> "KNOWS",
      Neo4jOptions.RELATIONSHIP_SAVE_STRATEGY -> "nAtIve",
      Neo4jOptions.RELATIONSHIP_SOURCE_SAVE_MODE -> "Errorifexists",
      Neo4jOptions.RELATIONSHIP_TARGET_SAVE_MODE -> "overwrite"
    ))
    assertEquals(RelationshipSaveStrategy.NATIVE, neo4jOptions.relationshipMetadata.saveStrategy)
    assertEquals(NodeSaveMode.ErrorIfExists, neo4jOptions.relationshipMetadata.sourceSaveMode)
    assertEquals(NodeSaveMode.Overwrite, neo4jOptions.relationshipMetadata.targetSaveMode)
  }

  @Test
  def testRelationshipWriteStrategyIsNotPresentShouldThrowException(): Unit = {
    val options = rawOptions(
      Neo4jOptions.URL -> "bolt://localhost",
      QueryType.LABELS.toString.toLowerCase -> "PERSON",
      "relationship.save.strategy" -> "nope"
    )

    _expectedException.expect(classOf[NoSuchElementException])
    _expectedException.expectMessage("No value found for 'NOPE'")

    new Neo4jOptions(options)
  }

  @Test
  def testQueryShouldHaveQueryType(): Unit = {
    val query = "MATCH n RETURN n"
    val neo4jOptions = new Neo4jOptions(rawOptions(
      Neo4jOptions.URL -> "bolt://localhost",
      QueryType.QUERY.toString.toLowerCase -> query
    ))
    assertEquals(QueryType.QUERY, neo4jOptions.query.queryType)
    assertEquals(query, neo4jOptions.query.value)
  }

  @Test
  def testNodeShouldHaveLabelType(): Unit = {
    val label = "Person"
    val neo4jOptions = new Neo4jOptions(rawOptions(
      Neo4jOptions.URL -> "bolt://localhost",
      QueryType.LABELS.toString.toLowerCase -> label
    ))
    assertEquals(QueryType.LABELS, neo4jOptions.query.queryType)
    assertEquals(label, neo4jOptions.query.value)
  }

  @Test
  def testRelationshipShouldHaveRelationshipType(): Unit = {
    val relationship = "KNOWS"
    // NOTE(review): the value is registered under the labels key, so this
    // asserts LABELS, not RELATIONSHIP — confirm the intent of this test.
    val neo4jOptions = new Neo4jOptions(rawOptions(
      Neo4jOptions.URL -> "bolt://localhost",
      QueryType.LABELS.toString.toLowerCase -> relationship
    ))
    assertEquals(QueryType.LABELS, neo4jOptions.query.queryType)
    assertEquals(relationship, neo4jOptions.query.value)
  }

  @Test
  def testPushDownColumnIsDisabled(): Unit = {
    val neo4jOptions = new Neo4jOptions(rawOptions(
      Neo4jOptions.URL -> "bolt://localhost",
      "pushdown.columns.enabled" -> "false"
    ))
    assertFalse(neo4jOptions.pushdownColumnsEnabled)
  }

  @Test
  def testDriverDefaults(): Unit = {
    val neo4jOptions = new Neo4jOptions(rawOptions(
      Neo4jOptions.URL -> "bolt://localhost",
      QueryType.QUERY.toString.toLowerCase -> "MATCH n RETURN n"
    ))

    assertEquals("", neo4jOptions.session.database)
    assertEquals(AccessMode.READ, neo4jOptions.session.accessMode)
    assertEquals("basic", neo4jOptions.connection.auth)
    assertEquals(Neo4jOptions.DEFAULT_AUTH_PARAMETERS, neo4jOptions.connection.authParameters)
    assertEquals(false, neo4jOptions.connection.encryption)
    assertEquals(None, neo4jOptions.connection.trustStrategy)
    assertEquals("", neo4jOptions.connection.certificatePath)
    assertEquals(Neo4jOptions.DEFAULT_CONNECTION_MAX_LIFETIME_MSECS, neo4jOptions.connection.lifetime)
    assertEquals(-1, neo4jOptions.connection.acquisitionTimeout)
    assertEquals(-1, neo4jOptions.connection.connectionTimeout)
    assertEquals(
      Neo4jOptions.DEFAULT_CONNECTION_LIVENESS_CHECK_TIMEOUT_MSECS,
      neo4jOptions.connection.livenessCheckTimeout
    )
    assertEquals(RelationshipSaveStrategy.NATIVE, neo4jOptions.relationshipMetadata.saveStrategy)
    assertTrue(neo4jOptions.pushdownFiltersEnabled)
  }

  @Test
  def testApocConfiguration(): Unit = {
    val neo4jOptions = new Neo4jOptions(rawOptions(
      "apoc.meta.nodeTypeProperties" -> """{"nodeLabels": ["Label"], "mandatory": false}""",
      Neo4jOptions.URL -> "bolt://localhost"
    ))
    val expected = Map("apoc.meta.nodeTypeProperties" -> Map(
      "nodeLabels" -> Seq("Label").asJava,
      "mandatory" -> false
    ))
    assertEquals(neo4jOptions.apocConfig.procedureConfigMap, expected)
  }

  @Test
  def testUnexistingProperty(): Unit = {
    // A null option value must be treated as "not provided".
    val neo4jOptions = new Neo4jOptions(rawOptions(
      "relationship.properties" -> null,
      Neo4jOptions.URL -> "bolt://localhost"
    ))
    assertEquals(neo4jOptions.relationshipMetadata.properties, None)
  }

  @Test
  def testUrls(): Unit = {
    val neo4jOptions = new Neo4jOptions(rawOptions(
      Neo4jOptions.URL -> "neo4j://localhost, neo4j://foo.bar:7687, neo4j://foo.bar.baz:7783"
    ))
    // First address is the base URL, the remaining ones feed the resolver.
    val (baseUrl, resolvers) = neo4jOptions.connection.connectionUrls
    assertEquals(URI.create("neo4j://localhost"), baseUrl)
    assertEquals(Set(ServerAddress.of("foo.bar", 7687), ServerAddress.of("foo.bar.baz", 7783)), resolvers)
  }

  @Test
  def testGdsProperties(): Unit = {
    val neo4jOptions = new Neo4jOptions(rawOptions(
      Neo4jOptions.URL -> "neo4j://localhost,neo4j://foo.bar,neo4j://foo.bar.baz:7783",
      "gds" -> "gds.pageRank.stream",
      "gds.graphName" -> "myGraph",
      "gds.configuration.concurrency" -> "2"
    ))
    assertEquals(QueryType.GDS, neo4jOptions.query.queryType)
    assertEquals("gds.pageRank.stream", neo4jOptions.query.value)
    assertEquals(
      Map(
        "graphName" -> "myGraph",
        "configuration" -> Map("concurrency" -> 2).asJava
      ).asJava,
      neo4jOptions.gdsMetadata.parameters
    )
  }

  @Test
  def testTransactionTimeout(): Unit = {
    // Given a Neo4j options with transaction timeout set
    val neo4jOptions = new Neo4jOptions(rawOptions(
      Neo4jOptions.URL -> "neo4j://localhost,neo4j://foo.bar,neo4j://foo.bar.baz:7783",
      "db.transaction.timeout" -> "1000"
    ))

    // When it converts to TransactionConfig
    val transactionConfig = neo4jOptions.toNeo4jTransactionConfig

    // Then it has the correct duration
    assertEquals(Duration.ofMillis(1000), transactionConfig.timeout())
  }

  @Test
  def testDefaultTransactionTimeout(): Unit = {
    // Given a Neo4j options with no explicit transaction timeout set
    val neo4jOptions = new Neo4jOptions(rawOptions(
      Neo4jOptions.URL -> "neo4j://localhost,neo4j://foo.bar,neo4j://foo.bar.baz:7783"
    ))

    // When it converts to TransactionConfig
    val transactionConfig = neo4jOptions.toNeo4jTransactionConfig

    // Then it is not set
    assertNull(transactionConfig.timeout())
  }
}
================================================
FILE: common/src/test/scala/org/neo4j/spark/util/Neo4jUtilTest.scala
================================================
/*
 * Copyright (c) "Neo4j"
 * Neo4j Sweden AB [https://neo4j.com]
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
* You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ package org.neo4j.spark.util import org.apache.commons.lang3.StringUtils import org.junit.Assert import org.junit.Test class Neo4jUtilTest { @Test def testSafetyCloseShouldNotFailWithNull(): Unit = { Neo4jUtil.closeSafely(null) } @Test def testConnectorEnv(): Unit = { val expected = if (StringUtils.isNotBlank(System.getenv("DATABRICKS_RUNTIME_VERSION"))) { "databricks" } else { "spark" } val actual = Neo4jUtil.connectorEnv Assert.assertEquals(expected, actual) } @Test def testConnectorEnvForCustom(): Unit = { System.setProperty("neo4j.spark.platform", "abc") val actual = Neo4jUtil.connectorEnv Assert.assertEquals("abc", actual) } } ================================================ FILE: common/src/test/scala/org/neo4j/spark/util/ValidationsIT.scala ================================================ /* * Copyright (c) "Neo4j" * Neo4j Sweden AB [https://neo4j.com] * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. 
 */
package org.neo4j.spark.util

import org.hamcrest.CoreMatchers
import org.junit.Rule
import org.junit.Test
import org.junit.rules.ExpectedException
import org.neo4j.driver.AccessMode
import org.neo4j.spark.SparkConnectorScalaSuiteIT
import org.neo4j.spark.SparkConnectorScalaSuiteIT.neo4j
import org.neo4j.spark.TestUtil

import java.util.regex.Pattern

import scala.annotation.meta.getter

/**
 * Integration tests verifying that Validations.validate rejects invalid read and
 * write configurations against a live Neo4j server (provided by the suite).
 * Each test registers its expectations on the ExpectedException rule BEFORE
 * triggering the validation — the ordering is mandatory for the rule to work.
 */
class ValidationsIT extends SparkConnectorScalaSuiteIT {

  // @getter meta-annotation makes @Rule apply to the generated getter, as JUnit requires.
  @(Rule @getter)
  val expectedException: ExpectedException = ExpectedException.none

  // A read query with a Cypher syntax error must fail validation with a message
  // that carries both the driver error and the offending query text.
  @Test
  def testReadQueryShouldBeSyntacticallyInvalid(): Unit = {
    // then
    expectedException.expect(classOf[IllegalArgumentException])
    expectedException.expectMessage(
      CoreMatchers.containsString("Query not compiled for the following exception: ClientException: Invalid input ")
    )
    val query = "MATCH (f{) RETURN f"
    expectedException.expectMessage(CoreMatchers.containsString(query))

    // given
    val readOpts: java.util.Map[String, String] = new java.util.HashMap[String, String]()
    readOpts.put(Neo4jOptions.URL, SparkConnectorScalaSuiteIT.server.getBoltUrl)
    readOpts.put("query", query)

    // when
    Validations.validate(ValidateRead(neo4j, new Neo4jOptions(readOpts), "1"))
  }

  // A syntactically valid query that WRITES (MERGE) must be rejected for reads.
  @Test
  def testReadQueryShouldBeSemanticallyInvalid(): Unit = {
    // then
    val query = "MERGE (n:TestNode{id: 1}) RETURN n"
    expectedException.expect(classOf[IllegalArgumentException])
    expectedException.expectMessage(
      s"Invalid query `$query` because the accepted types are [READ_ONLY], but the actual type is READ_WRITE"
    )

    // given
    val readOpts: java.util.Map[String, String] = new java.util.HashMap[String, String]()
    readOpts.put(Neo4jOptions.URL, SparkConnectorScalaSuiteIT.server.getBoltUrl)
    readOpts.put("query", query)

    // when
    Validations.validate(ValidateRead(neo4j, new Neo4jOptions(readOpts), "1"))
  }

  // The optional `query.count` is validated too (as an EXPLAIN), independently of `query`.
  @Test
  def testReadQueryCountBeSyntacticallyInvalid(): Unit = {
    // then
    val query = "MATCH (f{) RETURN f"
    expectedException.expect(classOf[IllegalArgumentException])
    expectedException.expectMessage(CoreMatchers.containsString(
      "Query count not compiled for the following exception: ClientException: Invalid input "
    ))
    expectedException.expectMessage(CoreMatchers.containsString(s"EXPLAIN $query"))

    // given
    val readOpts: java.util.Map[String, String] = new java.util.HashMap[String, String]()
    readOpts.put(Neo4jOptions.URL, SparkConnectorScalaSuiteIT.server.getBoltUrl)
    readOpts.put("query", "MATCH (f) RETURN f")
    readOpts.put("query.count", query)

    // when
    Validations.validate(ValidateRead(neo4j, new Neo4jOptions(readOpts), "1"))
  }

  // Every semicolon-separated statement in `script` is validated; the error must
  // name the specific invalid statement ("RETUR 2 AS two").
  @Test
  def testScriptQueryCountShouldContainAnInvalidQuery(): Unit = {
    // then
    expectedException.expect(classOf[IllegalArgumentException])
    expectedException.expectMessage(
      CoreMatchers.containsString("The following queries inside the `script` are not valid,")
    )
    expectedException.expectMessage(
      CoreMatchers.containsString("Query not compiled for the following exception: ClientException: Invalid input ")
    )
    expectedException.expectMessage(CoreMatchers.containsString("EXPLAIN RETUR 2 AS two"))

    // given
    val readOpts: java.util.Map[String, String] = new java.util.HashMap[String, String]()
    readOpts.put(Neo4jOptions.URL, SparkConnectorScalaSuiteIT.server.getBoltUrl)
    readOpts.put("query", "MATCH (f) RETURN f")
    readOpts.put("script", "RETURN 1 AS one; RETUR 2 AS two; RETURN 3 AS three")

    // when
    Validations.validate(ValidateRead(neo4j, new Neo4jOptions(readOpts), "1"))
  }

  // Mirror of the syntactic read test for the write path (ValidateWrite).
  @Test
  def testWriteQueryShouldBeSyntacticallyInvalid(): Unit = {
    // then
    val query = "MERGE (f{) RETURN f"
    expectedException.expect(classOf[IllegalArgumentException])
    expectedException.expectMessage(
      CoreMatchers.containsString("Query not compiled for the following exception: ClientException: Invalid input ")
    )
    expectedException.expectMessage(CoreMatchers.containsString(query))

    // given
    val writeOpts: java.util.Map[String, String] = new java.util.HashMap[String, String]()
    writeOpts.put(Neo4jOptions.URL, SparkConnectorScalaSuiteIT.server.getBoltUrl)
    writeOpts.put(Neo4jOptions.ACCESS_MODE, AccessMode.WRITE.toString)
    writeOpts.put("query", query)

    // when
    Validations.validate(ValidateWrite(neo4j, new Neo4jOptions(writeOpts), "1", null))
  }

  // A read-only query must be rejected when the access mode is WRITE.
  @Test
  def testWriteQueryShouldBeSemanticallyInvalid(): Unit = {
    // then
    val query = "MATCH (n:TestNode{id: 1}) RETURN n"
    expectedException.expect(classOf[IllegalArgumentException])
    expectedException.expectMessage(
      s"Invalid query `$query` because the accepted types are [WRITE_ONLY, READ_WRITE], but the actual type is READ_ONLY"
    )

    // given
    val writeOpts: java.util.Map[String, String] = new java.util.HashMap[String, String]()
    writeOpts.put(Neo4jOptions.URL, SparkConnectorScalaSuiteIT.server.getBoltUrl)
    writeOpts.put(Neo4jOptions.ACCESS_MODE, AccessMode.WRITE.toString)
    writeOpts.put("query", query)

    // when
    Validations.validate(ValidateWrite(neo4j, new Neo4jOptions(writeOpts), "1", null))
  }
}


================================================
FILE: common/src/test/scala/org/neo4j/spark/util/ValidationsTest.scala
================================================
/*
 * Copyright (c) "Neo4j"
 * Neo4j Sweden AB [https://neo4j.com]
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
*/ package org.neo4j.spark.util import org.apache.spark.sql.SparkSession import org.junit import org.junit.Assert.assertEquals import org.junit.Test import org.neo4j.spark.SparkConnectorScalaBaseTSE class ValidationsTest extends SparkConnectorScalaBaseTSE { @Test def testVersionThrowsExceptionSparkVersionIsNotSupported(): Unit = { val sparkVersion = SparkSession.getActiveSession .map { _.version } .getOrElse("UNKNOWN") try { Validations.validate(ValidateSparkMinVersion("3.10000")) fail(s"should be thrown a ${classOf[IllegalArgumentException].getName}") } catch { case e: IllegalArgumentException => assertEquals( s"""Your current Spark version $sparkVersion is not supported by the current connector. |Please visit https://neo4j.com/developer/spark/overview/#_spark_compatibility to know which connector version you need. |""".stripMargin, e.getMessage ) case e: Throwable => fail(s"should be thrown a ${classOf[IllegalArgumentException].getName}, got ${e.getClass} instead") } } @Test def testVersionShouldBeValid(): Unit = { val fullVersion = SparkSession .getDefaultSession .map(_.version) .getOrElse("3.2") val baseVersion = fullVersion .split("\\.") .take(2) .mkString(".") Validations.validate(ValidateSparkMinVersion(s"$baseVersion.*")) Validations.validate(ValidateSparkMinVersion(fullVersion)) Validations.validate(ValidateSparkMinVersion(s"$fullVersion-amzn-0")) } @Test def testVersionShouldValidateTheVersion(): Unit = { val version = ValidateSparkMinVersion("2.3.0") junit.Assert.assertTrue(version.isSupported("2.3.0-amzn-1")) junit.Assert.assertTrue(version.isSupported("2.3.1-amzn-1")) junit.Assert.assertTrue(version.isSupported("3.3.0-amzn-1")) junit.Assert.assertTrue(version.isSupported("3.3.0")) junit.Assert.assertTrue(version.isSupported("3.1.0")) junit.Assert.assertTrue(version.isSupported("3.2.0")) junit.Assert.assertFalse(version.isSupported("2.2.10")) } } ================================================ FILE: dangerfile.mjs 
================================================ import load from '@commitlint/load'; import lint from '@commitlint/lint'; const minPRDescriptionLength = 10; // Utility functions const processReport = (type, report, warnOnly = false) => { if (report.warnings.length > 0) { warn( `${type} '${report.input}': ${report.warnings .map((w) => w.message) .join(', ')}`, ); } if (report.errors.length > 0) { const reportFn = warnOnly ? warn : fail; reportFn( `${type} '${report.input}': ${report.errors .map((e) => e.message) .join(', ')}`, ); } return report.valid || warnOnly ? Promise.resolve() : Promise.reject(); }; const reportCommitMessage = (report) => processReport('Commit Message', report, true); const reportPRTitle = (report) => processReport('PR Title', report, false); const lintMessage = (message, opts, reporter) => lint( message, opts.rules, opts.parserPreset ? {parserOpts: opts.parserPreset.parserOpts} : {}, ).then(reporter); const pr = danger.github.pr; // check commit messages and PR name schedule( Promise.all([ load({}, {file: './.commitlintrc.json', cwd: process.cwd()}).then( (opts) => Promise.all( danger.git.commits .map((c) => c.message) .map((m) => lintMessage(m, opts, reportCommitMessage)), ) .catch(() => markdown( '> All commits should follow ' + '[Conventional commits](https://cheatography.com/albelop/cheat-sheets/conventional-commits). 
' + 'It seems some of the commit messages are not following those rules, please fix them.', ), ) .then(() => lintMessage(pr.title, opts, reportPRTitle)) .catch(() => markdown( '> Pull request title should follow ' + '[Conventional commits](https://cheatography.com/albelop/cheat-sheets/conventional-commits).', ), ), ), new Promise((resolve, reject) => { // No PR is too small to include a description of why you made a change if (pr.body.length < minPRDescriptionLength) { warn(`:exclamation: Please include a description of your PR changes.`); markdown( '> Pull request should have a description of the underlying changes.', ); } }), ]), ); ================================================ FILE: examples/neo4j_data_engineering.ipynb ================================================ { "cells": [ { "cell_type": "markdown", "source": [ "Open this notebook in Google Colab \n", " \"Open\n", "" ], "metadata": { "id": "EhTThKJMxDCy" } }, { "cell_type": "markdown", "metadata": { "id": "7Nvb-_bYx359" }, "source": [ "# Example of a Simple data engineering workflow with Neo4j and Spark" ] }, { "cell_type": "markdown", "source": [ "This notebook contains a set of examples that explains how to extract insights from data using the Neo4j Connector for Apache Spark in a Data Engineering workflow with [AuraDB](https://neo4j.com/docs/aura/auradb/) our fully managed version of Neo4j database.\n", "\n", "The notebooks will enable you to test your knowledge with a set of exercises after each section.\n", "\n", "If you have any questions or problems feel free to write a post in the [Neo4j community forum](https://community.neo4j.com/) or in [Discord](https://discord.com/invite/neo4j).\n", "\n", "If you want more exercises feel free to open an issue in the [GitHub repository](https://github.com/neo4j/neo4j-spark-connector).\n", "\n", "Enjoy!" 
], "metadata": { "id": "e0bo6ido8tL7" } }, { "cell_type": "markdown", "metadata": { "id": "hXwkjQMnMXED" }, "source": [ "### Configure the Spark Environment" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "BhZwh-RAz6Bo" }, "outputs": [], "source": [ "!apt-get install openjdk-17-jdk-headless -qq > /dev/null" ] }, { "cell_type": "code", "source": [ "spark_version = '3.3.4'" ], "metadata": { "id": "gmEzhrux7Jek" }, "execution_count": null, "outputs": [] }, { "cell_type": "code", "source": [ "!wget -q https://dlcdn.apache.org/spark/spark-$spark_version/spark-$spark_version-bin-hadoop3.tgz" ], "metadata": { "id": "Ya6Nj_u3vdTL" }, "execution_count": null, "outputs": [] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "A3gsnSHl0F99" }, "outputs": [], "source": [ "!tar xf spark-$spark_version-bin-hadoop3.tgz" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "hSBQWKs90vSx" }, "outputs": [], "source": [ "!pip install -q findspark" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "tnW0a1Gj080k" }, "outputs": [], "source": [ "import os\n", "os.environ[\"JAVA_HOME\"] = \"/usr/lib/jvm/java-17-openjdk-amd64\"\n", "os.environ[\"SPARK_HOME\"] = f\"/content/spark-{spark_version}-bin-hadoop3\"" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "dlUBSezK1DpZ" }, "outputs": [], "source": [ "import findspark\n", "findspark.init()" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "rd5KWGQiOVDV" }, "outputs": [], "source": [ "neo4j_url = \"\" # put your neo4j url here" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "uXbi_82KOTzU" }, "outputs": [], "source": [ "neo4j_user = \"neo4j\" # put your neo4j user here" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "Sw50wjxxOUqt" }, "outputs": [], "source": [ "neo4j_password = \"\" # put your neo4j password here" ] }, { "cell_type": "code", 
"execution_count": null, "metadata": { "id": "dOUJ-W871Tur" }, "outputs": [], "source": [ "from pyspark.sql import SparkSession\n", "spark = (SparkSession.builder\n", " .master('local[*]')\n", " .appName('Data engineering workflow with Neo4j and Spark')\n", " .config('spark.ui.port', '4050')\n", " # Just to show dataframes as tables\n", " .config('spark.sql.repl.eagerEval.enabled', True)\n", " .config('spark.jars.packages', 'org.neo4j:neo4j-connector-apache-spark_2.12:5.1.0_for_spark_3')\n", " # As we're using always the same database instance we'll\n", " # define them as global variables\n", " # so we don't need to repeat them each time\n", " .config(\"neo4j.url\", neo4j_url)\n", " .config(\"neo4j.authentication.type\", \"basic\")\n", " .config(\"neo4j.authentication.basic.username\", neo4j_user)\n", " .config(\"neo4j.authentication.basic.password\", neo4j_password)\n", " .getOrCreate())\n", "spark" ] }, { "cell_type": "markdown", "source": [ "\n", "## Exercises prerequisite\n", "\n", "In this notebook we and going to test your knowledge. Some of the exercises require the Neo4j Python driver to check if the exercises are being solved correctly.\n", "\n", "*Neo4j Python Driver is required only for verifying the exercises when you persist data from Spark to Neo4j*\n", "\n", "**It's not required by the Spark connector!!!**\n", "\n", "We'll use [Cy2Py](https://github.com/conker84/cy2py), a Jupyter extension that easily allows you to connect to Neo4j and visualize data from Jupyter notebooks.\n", "For a detailed instruction about how to use it please dive into [this example](https://github.com/conker84/cy2py/blob/main/examples/Neo4j_Crime_Investigation_Dataset.ipynb)" ], "metadata": { "id": "b6_YNZnZ5GdT" } }, { "cell_type": "code", "source": [ "!pip install -q cy2py" ], "metadata": { "id": "f5ZZJylo5Bbz" }, "execution_count": null, "outputs": [] }, { "cell_type": "markdown", "metadata": { "id": "CsnO4C9X7vK0" }, "source": [ "### Configure an Aura instance\n", "\n", "
\n", "

Neo4j Aura DB is a fully managed cloud service: The zero-admin, always-on graph database for cloud developers.

\n", "\n", "Create a [free instance](https://console.neo4j.io/?ref=aura-lp&mpp=4bfb2414ab973c741b6f067bf06d5575&mpid=17f40ce03ac883-0f09bb214466c1-37677109-1ea000-17f40ce03ad975&_gl=1*ql4f6s*_ga*MTc2OTMwNjEwMy4xNjQ5NDI3MDE0*_ga_DL38Q8KGQC*MTY1MzQxMDQzMC43OS4xLjE2NTM0MTA3MjQuMA..&_ga=2.136543024.1659283742.1653295079-1769306103.1649427014&_gac=1.216269284.1653306922.CjwKCAjw4ayUBhA4EiwATWyBrl6dN0oaH9_btCfvzdhi77ieNP07GAkOYuz7wx9QEewBnG_FUIMg8xoCgLsQAvD_BwE)\n", "\n", "
" ] }, { "cell_type": "markdown", "source": [ "let's load the extension" ], "metadata": { "id": "uKYEPEgOcG2b" } }, { "cell_type": "code", "source": [ "%load_ext cy2py" ], "metadata": { "id": "38EeXF6icKOK" }, "execution_count": null, "outputs": [] }, { "cell_type": "markdown", "source": [ "#### Populate the database\n", "\n", "To perform this section go in the Neo4j Brower of your aura instance and paste the following query:\n", "\n", "
\n", "\n", "Show the Cypher query\n", "\n", "\n", "```cypher\n", "CREATE (TheMatrix:Movie {title:'The Matrix', released:1999, tagline:'Welcome to the Real World'})\n", "CREATE (Keanu:Person {name:'Keanu Reeves', born:1964})\n", "CREATE (Carrie:Person {name:'Carrie-Anne Moss', born:1967})\n", "CREATE (Laurence:Person {name:'Laurence Fishburne', born:1961})\n", "CREATE (Hugo:Person {name:'Hugo Weaving', born:1960})\n", "CREATE (LillyW:Person {name:'Lilly Wachowski', born:1967})\n", "CREATE (LanaW:Person {name:'Lana Wachowski', born:1965})\n", "CREATE (JoelS:Person {name:'Joel Silver', born:1952})\n", "CREATE\n", "(Keanu)-[:ACTED_IN {roles:['Neo']}]->(TheMatrix),\n", "(Carrie)-[:ACTED_IN {roles:['Trinity']}]->(TheMatrix),\n", "(Laurence)-[:ACTED_IN {roles:['Morpheus']}]->(TheMatrix),\n", "(Hugo)-[:ACTED_IN {roles:['Agent Smith']}]->(TheMatrix),\n", "(LillyW)-[:DIRECTED]->(TheMatrix),\n", "(LanaW)-[:DIRECTED]->(TheMatrix),\n", "(JoelS)-[:PRODUCED]->(TheMatrix)\n", "\n", "CREATE (Emil:Person {name:\"Emil Eifrem\", born:1978})\n", "CREATE (Emil)-[:ACTED_IN {roles:[\"Emil\"]}]->(TheMatrix)\n", "\n", "CREATE (TheMatrixReloaded:Movie {title:'The Matrix Reloaded', released:2003, tagline:'Free your mind'})\n", "CREATE\n", "(Keanu)-[:ACTED_IN {roles:['Neo']}]->(TheMatrixReloaded),\n", "(Carrie)-[:ACTED_IN {roles:['Trinity']}]->(TheMatrixReloaded),\n", "(Laurence)-[:ACTED_IN {roles:['Morpheus']}]->(TheMatrixReloaded),\n", "(Hugo)-[:ACTED_IN {roles:['Agent Smith']}]->(TheMatrixReloaded),\n", "(LillyW)-[:DIRECTED]->(TheMatrixReloaded),\n", "(LanaW)-[:DIRECTED]->(TheMatrixReloaded),\n", "(JoelS)-[:PRODUCED]->(TheMatrixReloaded)\n", "\n", "CREATE (TheMatrixRevolutions:Movie {title:'The Matrix Revolutions', released:2003, tagline:'Everything that has a beginning has an end'})\n", "CREATE\n", "(Keanu)-[:ACTED_IN {roles:['Neo']}]->(TheMatrixRevolutions),\n", "(Carrie)-[:ACTED_IN {roles:['Trinity']}]->(TheMatrixRevolutions),\n", "(Laurence)-[:ACTED_IN 
{roles:['Morpheus']}]->(TheMatrixRevolutions),\n", "(Hugo)-[:ACTED_IN {roles:['Agent Smith']}]->(TheMatrixRevolutions),\n", "(LillyW)-[:DIRECTED]->(TheMatrixRevolutions),\n", "(LanaW)-[:DIRECTED]->(TheMatrixRevolutions),\n", "(JoelS)-[:PRODUCED]->(TheMatrixRevolutions)\n", "\n", "CREATE (TheDevilsAdvocate:Movie {title:\"The Devil's Advocate\", released:1997, tagline:'Evil has its winning ways'})\n", "CREATE (Charlize:Person {name:'Charlize Theron', born:1975})\n", "CREATE (Al:Person {name:'Al Pacino', born:1940})\n", "CREATE (Taylor:Person {name:'Taylor Hackford', born:1944})\n", "CREATE\n", "(Keanu)-[:ACTED_IN {roles:['Kevin Lomax']}]->(TheDevilsAdvocate),\n", "(Charlize)-[:ACTED_IN {roles:['Mary Ann Lomax']}]->(TheDevilsAdvocate),\n", "(Al)-[:ACTED_IN {roles:['John Milton']}]->(TheDevilsAdvocate),\n", "(Taylor)-[:DIRECTED]->(TheDevilsAdvocate)\n", "\n", "CREATE (AFewGoodMen:Movie {title:\"A Few Good Men\", released:1992, tagline:\"In the heart of the nation's capital, in a courthouse of the U.S. government, one man will stop at nothing to keep his honor, and one will stop at nothing to find the truth.\"})\n", "CREATE (TomC:Person {name:'Tom Cruise', born:1962})\n", "CREATE (JackN:Person {name:'Jack Nicholson', born:1937})\n", "CREATE (DemiM:Person {name:'Demi Moore', born:1962})\n", "CREATE (KevinB:Person {name:'Kevin Bacon', born:1958})\n", "CREATE (KieferS:Person {name:'Kiefer Sutherland', born:1966})\n", "CREATE (NoahW:Person {name:'Noah Wyle', born:1971})\n", "CREATE (CubaG:Person {name:'Cuba Gooding Jr.', born:1968})\n", "CREATE (KevinP:Person {name:'Kevin Pollak', born:1957})\n", "CREATE (JTW:Person {name:'J.T. Walsh', born:1943})\n", "CREATE (JamesM:Person {name:'James Marshall', born:1967})\n", "CREATE (ChristopherG:Person {name:'Christopher Guest', born:1948})\n", "CREATE (RobR:Person {name:'Rob Reiner', born:1947})\n", "CREATE (AaronS:Person {name:'Aaron Sorkin', born:1961})\n", "CREATE\n", "(TomC)-[:ACTED_IN {roles:['Lt. 
Daniel Kaffee']}]->(AFewGoodMen),\n", "(JackN)-[:ACTED_IN {roles:['Col. Nathan R. Jessup']}]->(AFewGoodMen),\n", "(DemiM)-[:ACTED_IN {roles:['Lt. Cdr. JoAnne Galloway']}]->(AFewGoodMen),\n", "(KevinB)-[:ACTED_IN {roles:['Capt. Jack Ross']}]->(AFewGoodMen),\n", "(KieferS)-[:ACTED_IN {roles:['Lt. Jonathan Kendrick']}]->(AFewGoodMen),\n", "(NoahW)-[:ACTED_IN {roles:['Cpl. Jeffrey Barnes']}]->(AFewGoodMen),\n", "(CubaG)-[:ACTED_IN {roles:['Cpl. Carl Hammaker']}]->(AFewGoodMen),\n", "(KevinP)-[:ACTED_IN {roles:['Lt. Sam Weinberg']}]->(AFewGoodMen),\n", "(JTW)-[:ACTED_IN {roles:['Lt. Col. Matthew Andrew Markinson']}]->(AFewGoodMen),\n", "(JamesM)-[:ACTED_IN {roles:['Pfc. Louden Downey']}]->(AFewGoodMen),\n", "(ChristopherG)-[:ACTED_IN {roles:['Dr. Stone']}]->(AFewGoodMen),\n", "(AaronS)-[:ACTED_IN {roles:['Man in Bar']}]->(AFewGoodMen),\n", "(RobR)-[:DIRECTED]->(AFewGoodMen),\n", "(AaronS)-[:WROTE]->(AFewGoodMen)\n", "\n", "CREATE (TopGun:Movie {title:\"Top Gun\", released:1986, tagline:'I feel the need, the need for speed.'})\n", "CREATE (KellyM:Person {name:'Kelly McGillis', born:1957})\n", "CREATE (ValK:Person {name:'Val Kilmer', born:1959})\n", "CREATE (AnthonyE:Person {name:'Anthony Edwards', born:1962})\n", "CREATE (TomS:Person {name:'Tom Skerritt', born:1933})\n", "CREATE (MegR:Person {name:'Meg Ryan', born:1961})\n", "CREATE (TonyS:Person {name:'Tony Scott', born:1944})\n", "CREATE (JimC:Person {name:'Jim Cash', born:1941})\n", "CREATE\n", "(TomC)-[:ACTED_IN {roles:['Maverick']}]->(TopGun),\n", "(KellyM)-[:ACTED_IN {roles:['Charlie']}]->(TopGun),\n", "(ValK)-[:ACTED_IN {roles:['Iceman']}]->(TopGun),\n", "(AnthonyE)-[:ACTED_IN {roles:['Goose']}]->(TopGun),\n", "(TomS)-[:ACTED_IN {roles:['Viper']}]->(TopGun),\n", "(MegR)-[:ACTED_IN {roles:['Carole']}]->(TopGun),\n", "(TonyS)-[:DIRECTED]->(TopGun),\n", "(JimC)-[:WROTE]->(TopGun)\n", "\n", "CREATE (JerryMaguire:Movie {title:'Jerry Maguire', released:2000, tagline:'The rest of his life begins now.'})\n", "CREATE 
(ReneeZ:Person {name:'Renee Zellweger', born:1969})\n", "CREATE (KellyP:Person {name:'Kelly Preston', born:1962})\n", "CREATE (JerryO:Person {name:\"Jerry O'Connell\", born:1974})\n", "CREATE (JayM:Person {name:'Jay Mohr', born:1970})\n", "CREATE (BonnieH:Person {name:'Bonnie Hunt', born:1961})\n", "CREATE (ReginaK:Person {name:'Regina King', born:1971})\n", "CREATE (JonathanL:Person {name:'Jonathan Lipnicki', born:1996})\n", "CREATE (CameronC:Person {name:'Cameron Crowe', born:1957})\n", "CREATE\n", "(TomC)-[:ACTED_IN {roles:['Jerry Maguire']}]->(JerryMaguire),\n", "(CubaG)-[:ACTED_IN {roles:['Rod Tidwell']}]->(JerryMaguire),\n", "(ReneeZ)-[:ACTED_IN {roles:['Dorothy Boyd']}]->(JerryMaguire),\n", "(KellyP)-[:ACTED_IN {roles:['Avery Bishop']}]->(JerryMaguire),\n", "(JerryO)-[:ACTED_IN {roles:['Frank Cushman']}]->(JerryMaguire),\n", "(JayM)-[:ACTED_IN {roles:['Bob Sugar']}]->(JerryMaguire),\n", "(BonnieH)-[:ACTED_IN {roles:['Laurel Boyd']}]->(JerryMaguire),\n", "(ReginaK)-[:ACTED_IN {roles:['Marcee Tidwell']}]->(JerryMaguire),\n", "(JonathanL)-[:ACTED_IN {roles:['Ray Boyd']}]->(JerryMaguire),\n", "(CameronC)-[:DIRECTED]->(JerryMaguire),\n", "(CameronC)-[:PRODUCED]->(JerryMaguire),\n", "(CameronC)-[:WROTE]->(JerryMaguire)\n", "\n", "CREATE (StandByMe:Movie {title:\"Stand By Me\", released:1986, tagline:\"For some, it's the last real taste of innocence, and the first real taste of life. 
But for everyone, it's the time that memories are made of.\"})\n", "CREATE (RiverP:Person {name:'River Phoenix', born:1970})\n", "CREATE (CoreyF:Person {name:'Corey Feldman', born:1971})\n", "CREATE (WilW:Person {name:'Wil Wheaton', born:1972})\n", "CREATE (JohnC:Person {name:'John Cusack', born:1966})\n", "CREATE (MarshallB:Person {name:'Marshall Bell', born:1942})\n", "CREATE\n", "(WilW)-[:ACTED_IN {roles:['Gordie Lachance']}]->(StandByMe),\n", "(RiverP)-[:ACTED_IN {roles:['Chris Chambers']}]->(StandByMe),\n", "(JerryO)-[:ACTED_IN {roles:['Vern Tessio']}]->(StandByMe),\n", "(CoreyF)-[:ACTED_IN {roles:['Teddy Duchamp']}]->(StandByMe),\n", "(JohnC)-[:ACTED_IN {roles:['Denny Lachance']}]->(StandByMe),\n", "(KieferS)-[:ACTED_IN {roles:['Ace Merrill']}]->(StandByMe),\n", "(MarshallB)-[:ACTED_IN {roles:['Mr. Lachance']}]->(StandByMe),\n", "(RobR)-[:DIRECTED]->(StandByMe)\n", "\n", "CREATE (AsGoodAsItGets:Movie {title:'As Good as It Gets', released:1997, tagline:'A comedy from the heart that goes for the throat.'})\n", "CREATE (HelenH:Person {name:'Helen Hunt', born:1963})\n", "CREATE (GregK:Person {name:'Greg Kinnear', born:1963})\n", "CREATE (JamesB:Person {name:'James L. Brooks', born:1940})\n", "CREATE\n", "(JackN)-[:ACTED_IN {roles:['Melvin Udall']}]->(AsGoodAsItGets),\n", "(HelenH)-[:ACTED_IN {roles:['Carol Connelly']}]->(AsGoodAsItGets),\n", "(GregK)-[:ACTED_IN {roles:['Simon Bishop']}]->(AsGoodAsItGets),\n", "(CubaG)-[:ACTED_IN {roles:['Frank Sachs']}]->(AsGoodAsItGets),\n", "(JamesB)-[:DIRECTED]->(AsGoodAsItGets)\n", "\n", "CREATE (WhatDreamsMayCome:Movie {title:'What Dreams May Come', released:1998, tagline:'After life there is more. 
The end is just the beginning.'})\n", "CREATE (AnnabellaS:Person {name:'Annabella Sciorra', born:1960})\n", "CREATE (MaxS:Person {name:'Max von Sydow', born:1929})\n", "CREATE (WernerH:Person {name:'Werner Herzog', born:1942})\n", "CREATE (Robin:Person {name:'Robin Williams', born:1951})\n", "CREATE (VincentW:Person {name:'Vincent Ward', born:1956})\n", "CREATE\n", "(Robin)-[:ACTED_IN {roles:['Chris Nielsen']}]->(WhatDreamsMayCome),\n", "(CubaG)-[:ACTED_IN {roles:['Albert Lewis']}]->(WhatDreamsMayCome),\n", "(AnnabellaS)-[:ACTED_IN {roles:['Annie Collins-Nielsen']}]->(WhatDreamsMayCome),\n", "(MaxS)-[:ACTED_IN {roles:['The Tracker']}]->(WhatDreamsMayCome),\n", "(WernerH)-[:ACTED_IN {roles:['The Face']}]->(WhatDreamsMayCome),\n", "(VincentW)-[:DIRECTED]->(WhatDreamsMayCome)\n", "\n", "CREATE (SnowFallingonCedars:Movie {title:'Snow Falling on Cedars', released:1999, tagline:'First loves last. Forever.'})\n", "CREATE (EthanH:Person {name:'Ethan Hawke', born:1970})\n", "CREATE (RickY:Person {name:'Rick Yune', born:1971})\n", "CREATE (JamesC:Person {name:'James Cromwell', born:1940})\n", "CREATE (ScottH:Person {name:'Scott Hicks', born:1953})\n", "CREATE\n", "(EthanH)-[:ACTED_IN {roles:['Ishmael Chambers']}]->(SnowFallingonCedars),\n", "(RickY)-[:ACTED_IN {roles:['Kazuo Miyamoto']}]->(SnowFallingonCedars),\n", "(MaxS)-[:ACTED_IN {roles:['Nels Gudmundsson']}]->(SnowFallingonCedars),\n", "(JamesC)-[:ACTED_IN {roles:['Judge Fielding']}]->(SnowFallingonCedars),\n", "(ScottH)-[:DIRECTED]->(SnowFallingonCedars)\n", "\n", "CREATE (YouveGotMail:Movie {title:\"You've Got Mail\", released:1998, tagline:'At odds in life... 
in love on-line.'})\n", "CREATE (ParkerP:Person {name:'Parker Posey', born:1968})\n", "CREATE (DaveC:Person {name:'Dave Chappelle', born:1973})\n", "CREATE (SteveZ:Person {name:'Steve Zahn', born:1967})\n", "CREATE (TomH:Person {name:'Tom Hanks', born:1956})\n", "CREATE (NoraE:Person {name:'Nora Ephron', born:1941})\n", "CREATE\n", "(TomH)-[:ACTED_IN {roles:['Joe Fox']}]->(YouveGotMail),\n", "(MegR)-[:ACTED_IN {roles:['Kathleen Kelly']}]->(YouveGotMail),\n", "(GregK)-[:ACTED_IN {roles:['Frank Navasky']}]->(YouveGotMail),\n", "(ParkerP)-[:ACTED_IN {roles:['Patricia Eden']}]->(YouveGotMail),\n", "(DaveC)-[:ACTED_IN {roles:['Kevin Jackson']}]->(YouveGotMail),\n", "(SteveZ)-[:ACTED_IN {roles:['George Pappas']}]->(YouveGotMail),\n", "(NoraE)-[:DIRECTED]->(YouveGotMail)\n", "\n", "CREATE (SleeplessInSeattle:Movie {title:'Sleepless in Seattle', released:1993, tagline:'What if someone you never met, someone you never saw, someone you never knew was the only someone for you?'})\n", "CREATE (RitaW:Person {name:'Rita Wilson', born:1956})\n", "CREATE (BillPull:Person {name:'Bill Pullman', born:1953})\n", "CREATE (VictorG:Person {name:'Victor Garber', born:1949})\n", "CREATE (RosieO:Person {name:\"Rosie O'Donnell\", born:1962})\n", "CREATE\n", "(TomH)-[:ACTED_IN {roles:['Sam Baldwin']}]->(SleeplessInSeattle),\n", "(MegR)-[:ACTED_IN {roles:['Annie Reed']}]->(SleeplessInSeattle),\n", "(RitaW)-[:ACTED_IN {roles:['Suzy']}]->(SleeplessInSeattle),\n", "(BillPull)-[:ACTED_IN {roles:['Walter']}]->(SleeplessInSeattle),\n", "(VictorG)-[:ACTED_IN {roles:['Greg']}]->(SleeplessInSeattle),\n", "(RosieO)-[:ACTED_IN {roles:['Becky']}]->(SleeplessInSeattle),\n", "(NoraE)-[:DIRECTED]->(SleeplessInSeattle)\n", "\n", "CREATE (JoeVersustheVolcano:Movie {title:'Joe Versus the Volcano', released:1990, tagline:'A story of love, lava and burning desire.'})\n", "CREATE (JohnS:Person {name:'John Patrick Stanley', born:1950})\n", "CREATE (Nathan:Person {name:'Nathan Lane', born:1956})\n", "CREATE\n", 
"(TomH)-[:ACTED_IN {roles:['Joe Banks']}]->(JoeVersustheVolcano),\n", "(MegR)-[:ACTED_IN {roles:['DeDe', 'Angelica Graynamore', 'Patricia Graynamore']}]->(JoeVersustheVolcano),\n", "(Nathan)-[:ACTED_IN {roles:['Baw']}]->(JoeVersustheVolcano),\n", "(JohnS)-[:DIRECTED]->(JoeVersustheVolcano)\n", "\n", "CREATE (WhenHarryMetSally:Movie {title:'When Harry Met Sally', released:1998, tagline:'Can two friends sleep together and still love each other in the morning?'})\n", "CREATE (BillyC:Person {name:'Billy Crystal', born:1948})\n", "CREATE (CarrieF:Person {name:'Carrie Fisher', born:1956})\n", "CREATE (BrunoK:Person {name:'Bruno Kirby', born:1949})\n", "CREATE\n", "(BillyC)-[:ACTED_IN {roles:['Harry Burns']}]->(WhenHarryMetSally),\n", "(MegR)-[:ACTED_IN {roles:['Sally Albright']}]->(WhenHarryMetSally),\n", "(CarrieF)-[:ACTED_IN {roles:['Marie']}]->(WhenHarryMetSally),\n", "(BrunoK)-[:ACTED_IN {roles:['Jess']}]->(WhenHarryMetSally),\n", "(RobR)-[:DIRECTED]->(WhenHarryMetSally),\n", "(RobR)-[:PRODUCED]->(WhenHarryMetSally),\n", "(NoraE)-[:PRODUCED]->(WhenHarryMetSally),\n", "(NoraE)-[:WROTE]->(WhenHarryMetSally)\n", "\n", "CREATE (ThatThingYouDo:Movie {title:'That Thing You Do', released:1996, tagline:'In every life there comes a time when that thing you dream becomes that thing you do'})\n", "CREATE (LivT:Person {name:'Liv Tyler', born:1977})\n", "CREATE\n", "(TomH)-[:ACTED_IN {roles:['Mr. White']}]->(ThatThingYouDo),\n", "(LivT)-[:ACTED_IN {roles:['Faye Dolan']}]->(ThatThingYouDo),\n", "(Charlize)-[:ACTED_IN {roles:['Tina']}]->(ThatThingYouDo),\n", "(TomH)-[:DIRECTED]->(ThatThingYouDo)\n", "\n", "CREATE (TheReplacements:Movie {title:'The Replacements', released:2000, tagline:'Pain heals, Chicks dig scars... 
Glory lasts forever'})\n", "CREATE (Brooke:Person {name:'Brooke Langton', born:1970})\n", "CREATE (Gene:Person {name:'Gene Hackman', born:1930})\n", "CREATE (Orlando:Person {name:'Orlando Jones', born:1968})\n", "CREATE (Howard:Person {name:'Howard Deutch', born:1950})\n", "CREATE\n", "(Keanu)-[:ACTED_IN {roles:['Shane Falco']}]->(TheReplacements),\n", "(Brooke)-[:ACTED_IN {roles:['Annabelle Farrell']}]->(TheReplacements),\n", "(Gene)-[:ACTED_IN {roles:['Jimmy McGinty']}]->(TheReplacements),\n", "(Orlando)-[:ACTED_IN {roles:['Clifford Franklin']}]->(TheReplacements),\n", "(Howard)-[:DIRECTED]->(TheReplacements)\n", "\n", "CREATE (RescueDawn:Movie {title:'RescueDawn', released:2006, tagline:\"Based on the extraordinary true story of one man's fight for freedom\"})\n", "CREATE (ChristianB:Person {name:'Christian Bale', born:1974})\n", "CREATE (ZachG:Person {name:'Zach Grenier', born:1954})\n", "CREATE\n", "(MarshallB)-[:ACTED_IN {roles:['Admiral']}]->(RescueDawn),\n", "(ChristianB)-[:ACTED_IN {roles:['Dieter Dengler']}]->(RescueDawn),\n", "(ZachG)-[:ACTED_IN {roles:['Squad Leader']}]->(RescueDawn),\n", "(SteveZ)-[:ACTED_IN {roles:['Duane']}]->(RescueDawn),\n", "(WernerH)-[:DIRECTED]->(RescueDawn)\n", "\n", "CREATE (TheBirdcage:Movie {title:'The Birdcage', released:1996, tagline:'Come as you are'})\n", "CREATE (MikeN:Person {name:'Mike Nichols', born:1931})\n", "CREATE\n", "(Robin)-[:ACTED_IN {roles:['Armand Goldman']}]->(TheBirdcage),\n", "(Nathan)-[:ACTED_IN {roles:['Albert Goldman']}]->(TheBirdcage),\n", "(Gene)-[:ACTED_IN {roles:['Sen. 
Kevin Keeley']}]->(TheBirdcage),\n", "(MikeN)-[:DIRECTED]->(TheBirdcage)\n", "\n", "CREATE (Unforgiven:Movie {title:'Unforgiven', released:1992, tagline:\"It's a hell of a thing, killing a man\"})\n", "CREATE (RichardH:Person {name:'Richard Harris', born:1930})\n", "CREATE (ClintE:Person {name:'Clint Eastwood', born:1930})\n", "CREATE\n", "(RichardH)-[:ACTED_IN {roles:['English Bob']}]->(Unforgiven),\n", "(ClintE)-[:ACTED_IN {roles:['Bill Munny']}]->(Unforgiven),\n", "(Gene)-[:ACTED_IN {roles:['Little Bill Daggett']}]->(Unforgiven),\n", "(ClintE)-[:DIRECTED]->(Unforgiven)\n", "\n", "CREATE (JohnnyMnemonic:Movie {title:'Johnny Mnemonic', released:1995, tagline:'The hottest data on earth. In the coolest head in town'})\n", "CREATE (Takeshi:Person {name:'Takeshi Kitano', born:1947})\n", "CREATE (Dina:Person {name:'Dina Meyer', born:1968})\n", "CREATE (IceT:Person {name:'Ice-T', born:1958})\n", "CREATE (RobertL:Person {name:'Robert Longo', born:1953})\n", "CREATE\n", "(Keanu)-[:ACTED_IN {roles:['Johnny Mnemonic']}]->(JohnnyMnemonic),\n", "(Takeshi)-[:ACTED_IN {roles:['Takahashi']}]->(JohnnyMnemonic),\n", "(Dina)-[:ACTED_IN {roles:['Jane']}]->(JohnnyMnemonic),\n", "(IceT)-[:ACTED_IN {roles:['J-Bone']}]->(JohnnyMnemonic),\n", "(RobertL)-[:DIRECTED]->(JohnnyMnemonic)\n", "\n", "CREATE (CloudAtlas:Movie {title:'Cloud Atlas', released:2012, tagline:'Everything is connected'})\n", "CREATE (HalleB:Person {name:'Halle Berry', born:1966})\n", "CREATE (JimB:Person {name:'Jim Broadbent', born:1949})\n", "CREATE (TomT:Person {name:'Tom Tykwer', born:1965})\n", "CREATE (DavidMitchell:Person {name:'David Mitchell', born:1969})\n", "CREATE (StefanArndt:Person {name:'Stefan Arndt', born:1961})\n", "CREATE\n", "(TomH)-[:ACTED_IN {roles:['Zachry', 'Dr. 
Henry Goose', 'Isaac Sachs', 'Dermot Hoggins']}]->(CloudAtlas),\n", "(Hugo)-[:ACTED_IN {roles:['Bill Smoke', 'Haskell Moore', 'Tadeusz Kesselring', 'Nurse Noakes', 'Boardman Mephi', 'Old Georgie']}]->(CloudAtlas),\n", "(HalleB)-[:ACTED_IN {roles:['Luisa Rey', 'Jocasta Ayrs', 'Ovid', 'Meronym']}]->(CloudAtlas),\n", "(JimB)-[:ACTED_IN {roles:['Vyvyan Ayrs', 'Captain Molyneux', 'Timothy Cavendish']}]->(CloudAtlas),\n", "(TomT)-[:DIRECTED]->(CloudAtlas),\n", "(LillyW)-[:DIRECTED]->(CloudAtlas),\n", "(LanaW)-[:DIRECTED]->(CloudAtlas),\n", "(DavidMitchell)-[:WROTE]->(CloudAtlas),\n", "(StefanArndt)-[:PRODUCED]->(CloudAtlas)\n", "\n", "CREATE (TheDaVinciCode:Movie {title:'The Da Vinci Code', released:2006, tagline:'Break The Codes'})\n", "CREATE (IanM:Person {name:'Ian McKellen', born:1939})\n", "CREATE (AudreyT:Person {name:'Audrey Tautou', born:1976})\n", "CREATE (PaulB:Person {name:'Paul Bettany', born:1971})\n", "CREATE (RonH:Person {name:'Ron Howard', born:1954})\n", "CREATE\n", "(TomH)-[:ACTED_IN {roles:['Dr. Robert Langdon']}]->(TheDaVinciCode),\n", "(IanM)-[:ACTED_IN {roles:['Sir Leight Teabing']}]->(TheDaVinciCode),\n", "(AudreyT)-[:ACTED_IN {roles:['Sophie Neveu']}]->(TheDaVinciCode),\n", "(PaulB)-[:ACTED_IN {roles:['Silas']}]->(TheDaVinciCode),\n", "(RonH)-[:DIRECTED]->(TheDaVinciCode)\n", "\n", "CREATE (VforVendetta:Movie {title:'V for Vendetta', released:2006, tagline:'Freedom! 
Forever!'})\n", "CREATE (NatalieP:Person {name:'Natalie Portman', born:1981})\n", "CREATE (StephenR:Person {name:'Stephen Rea', born:1946})\n", "CREATE (JohnH:Person {name:'John Hurt', born:1940})\n", "CREATE (BenM:Person {name: 'Ben Miles', born:1967})\n", "CREATE\n", "(Hugo)-[:ACTED_IN {roles:['V']}]->(VforVendetta),\n", "(NatalieP)-[:ACTED_IN {roles:['Evey Hammond']}]->(VforVendetta),\n", "(StephenR)-[:ACTED_IN {roles:['Eric Finch']}]->(VforVendetta),\n", "(JohnH)-[:ACTED_IN {roles:['High Chancellor Adam Sutler']}]->(VforVendetta),\n", "(BenM)-[:ACTED_IN {roles:['Dascomb']}]->(VforVendetta),\n", "(JamesM)-[:DIRECTED]->(VforVendetta),\n", "(LillyW)-[:PRODUCED]->(VforVendetta),\n", "(LanaW)-[:PRODUCED]->(VforVendetta),\n", "(JoelS)-[:PRODUCED]->(VforVendetta),\n", "(LillyW)-[:WROTE]->(VforVendetta),\n", "(LanaW)-[:WROTE]->(VforVendetta)\n", "\n", "CREATE (SpeedRacer:Movie {title:'Speed Racer', released:2008, tagline:'Speed has no limits'})\n", "CREATE (EmileH:Person {name:'Emile Hirsch', born:1985})\n", "CREATE (JohnG:Person {name:'John Goodman', born:1960})\n", "CREATE (SusanS:Person {name:'Susan Sarandon', born:1946})\n", "CREATE (MatthewF:Person {name:'Matthew Fox', born:1966})\n", "CREATE (ChristinaR:Person {name:'Christina Ricci', born:1980})\n", "CREATE (Rain:Person {name:'Rain', born:1982})\n", "CREATE\n", "(EmileH)-[:ACTED_IN {roles:['Speed Racer']}]->(SpeedRacer),\n", "(JohnG)-[:ACTED_IN {roles:['Pops']}]->(SpeedRacer),\n", "(SusanS)-[:ACTED_IN {roles:['Mom']}]->(SpeedRacer),\n", "(MatthewF)-[:ACTED_IN {roles:['Racer X']}]->(SpeedRacer),\n", "(ChristinaR)-[:ACTED_IN {roles:['Trixie']}]->(SpeedRacer),\n", "(Rain)-[:ACTED_IN {roles:['Taejo Togokahn']}]->(SpeedRacer),\n", "(BenM)-[:ACTED_IN {roles:['Cass Jones']}]->(SpeedRacer),\n", "(LillyW)-[:DIRECTED]->(SpeedRacer),\n", "(LanaW)-[:DIRECTED]->(SpeedRacer),\n", "(LillyW)-[:WROTE]->(SpeedRacer),\n", "(LanaW)-[:WROTE]->(SpeedRacer),\n", "(JoelS)-[:PRODUCED]->(SpeedRacer)\n", "\n", "CREATE (NinjaAssassin:Movie 
{title:'Ninja Assassin', released:2009, tagline:'Prepare to enter a secret world of assassins'})\n", "CREATE (NaomieH:Person {name:'Naomie Harris'})\n", "CREATE\n", "(Rain)-[:ACTED_IN {roles:['Raizo']}]->(NinjaAssassin),\n", "(NaomieH)-[:ACTED_IN {roles:['Mika Coretti']}]->(NinjaAssassin),\n", "(RickY)-[:ACTED_IN {roles:['Takeshi']}]->(NinjaAssassin),\n", "(BenM)-[:ACTED_IN {roles:['Ryan Maslow']}]->(NinjaAssassin),\n", "(JamesM)-[:DIRECTED]->(NinjaAssassin),\n", "(LillyW)-[:PRODUCED]->(NinjaAssassin),\n", "(LanaW)-[:PRODUCED]->(NinjaAssassin),\n", "(JoelS)-[:PRODUCED]->(NinjaAssassin)\n", "\n", "CREATE (TheGreenMile:Movie {title:'The Green Mile', released:1999, tagline:\"Walk a mile you'll never forget.\"})\n", "CREATE (MichaelD:Person {name:'Michael Clarke Duncan', born:1957})\n", "CREATE (DavidM:Person {name:'David Morse', born:1953})\n", "CREATE (SamR:Person {name:'Sam Rockwell', born:1968})\n", "CREATE (GaryS:Person {name:'Gary Sinise', born:1955})\n", "CREATE (PatriciaC:Person {name:'Patricia Clarkson', born:1959})\n", "CREATE (FrankD:Person {name:'Frank Darabont', born:1959})\n", "CREATE\n", "(TomH)-[:ACTED_IN {roles:['Paul Edgecomb']}]->(TheGreenMile),\n", "(MichaelD)-[:ACTED_IN {roles:['John Coffey']}]->(TheGreenMile),\n", "(DavidM)-[:ACTED_IN {roles:['Brutus \"Brutal\" Howell']}]->(TheGreenMile),\n", "(BonnieH)-[:ACTED_IN {roles:['Jan Edgecomb']}]->(TheGreenMile),\n", "(JamesC)-[:ACTED_IN {roles:['Warden Hal Moores']}]->(TheGreenMile),\n", "(SamR)-[:ACTED_IN {roles:['\"Wild Bill\" Wharton']}]->(TheGreenMile),\n", "(GaryS)-[:ACTED_IN {roles:['Burt Hammersmith']}]->(TheGreenMile),\n", "(PatriciaC)-[:ACTED_IN {roles:['Melinda Moores']}]->(TheGreenMile),\n", "(FrankD)-[:DIRECTED]->(TheGreenMile)\n", "\n", "CREATE (FrostNixon:Movie {title:'Frost/Nixon', released:2008, tagline:'400 million people were waiting for the truth.'})\n", "CREATE (FrankL:Person {name:'Frank Langella', born:1938})\n", "CREATE (MichaelS:Person {name:'Michael Sheen', born:1969})\n", 
"CREATE (OliverP:Person {name:'Oliver Platt', born:1960})\n", "CREATE\n", "(FrankL)-[:ACTED_IN {roles:['Richard Nixon']}]->(FrostNixon),\n", "(MichaelS)-[:ACTED_IN {roles:['David Frost']}]->(FrostNixon),\n", "(KevinB)-[:ACTED_IN {roles:['Jack Brennan']}]->(FrostNixon),\n", "(OliverP)-[:ACTED_IN {roles:['Bob Zelnick']}]->(FrostNixon),\n", "(SamR)-[:ACTED_IN {roles:['James Reston, Jr.']}]->(FrostNixon),\n", "(RonH)-[:DIRECTED]->(FrostNixon)\n", "\n", "CREATE (Hoffa:Movie {title:'Hoffa', released:1992, tagline:\"He didn't want law. He wanted justice.\"})\n", "CREATE (DannyD:Person {name:'Danny DeVito', born:1944})\n", "CREATE (JohnR:Person {name:'John C. Reilly', born:1965})\n", "CREATE\n", "(JackN)-[:ACTED_IN {roles:['Hoffa']}]->(Hoffa),\n", "(DannyD)-[:ACTED_IN {roles:['Robert \"Bobby\" Ciaro']}]->(Hoffa),\n", "(JTW)-[:ACTED_IN {roles:['Frank Fitzsimmons']}]->(Hoffa),\n", "(JohnR)-[:ACTED_IN {roles:['Peter \"Pete\" Connelly']}]->(Hoffa),\n", "(DannyD)-[:DIRECTED]->(Hoffa)\n", "\n", "CREATE (Apollo13:Movie {title:'Apollo 13', released:1995, tagline:'Houston, we have a problem.'})\n", "CREATE (EdH:Person {name:'Ed Harris', born:1950})\n", "CREATE (BillPax:Person {name:'Bill Paxton', born:1955})\n", "CREATE\n", "(TomH)-[:ACTED_IN {roles:['Jim Lovell']}]->(Apollo13),\n", "(KevinB)-[:ACTED_IN {roles:['Jack Swigert']}]->(Apollo13),\n", "(EdH)-[:ACTED_IN {roles:['Gene Kranz']}]->(Apollo13),\n", "(BillPax)-[:ACTED_IN {roles:['Fred Haise']}]->(Apollo13),\n", "(GaryS)-[:ACTED_IN {roles:['Ken Mattingly']}]->(Apollo13),\n", "(RonH)-[:DIRECTED]->(Apollo13)\n", "\n", "CREATE (Twister:Movie {title:'Twister', released:1996, tagline:\"Don't Breathe. Don't Look Back.\"})\n", "CREATE (PhilipH:Person {name:'Philip Seymour Hoffman', born:1967})\n", "CREATE (JanB:Person {name:'Jan de Bont', born:1943})\n", "CREATE\n", "(BillPax)-[:ACTED_IN {roles:['Bill Harding']}]->(Twister),\n", "(HelenH)-[:ACTED_IN {roles:['Dr. 
Jo Harding']}]->(Twister),\n", "(ZachG)-[:ACTED_IN {roles:['Eddie']}]->(Twister),\n", "(PhilipH)-[:ACTED_IN {roles:['Dustin \"Dusty\" Davis']}]->(Twister),\n", "(JanB)-[:DIRECTED]->(Twister)\n", "\n", "CREATE (CastAway:Movie {title:'Cast Away', released:2000, tagline:'At the edge of the world, his journey begins.'})\n", "CREATE (RobertZ:Person {name:'Robert Zemeckis', born:1951})\n", "CREATE\n", "(TomH)-[:ACTED_IN {roles:['Chuck Noland']}]->(CastAway),\n", "(HelenH)-[:ACTED_IN {roles:['Kelly Frears']}]->(CastAway),\n", "(RobertZ)-[:DIRECTED]->(CastAway)\n", "\n", "CREATE (OneFlewOvertheCuckoosNest:Movie {title:\"One Flew Over the Cuckoo's Nest\", released:1975, tagline:\"If he's crazy, what does that make you?\"})\n", "CREATE (MilosF:Person {name:'Milos Forman', born:1932})\n", "CREATE\n", "(JackN)-[:ACTED_IN {roles:['Randle McMurphy']}]->(OneFlewOvertheCuckoosNest),\n", "(DannyD)-[:ACTED_IN {roles:['Martini']}]->(OneFlewOvertheCuckoosNest),\n", "(MilosF)-[:DIRECTED]->(OneFlewOvertheCuckoosNest)\n", "\n", "CREATE (SomethingsGottaGive:Movie {title:\"Something's Gotta Give\", released:2003})\n", "CREATE (DianeK:Person {name:'Diane Keaton', born:1946})\n", "CREATE (NancyM:Person {name:'Nancy Meyers', born:1949})\n", "CREATE\n", "(JackN)-[:ACTED_IN {roles:['Harry Sanborn']}]->(SomethingsGottaGive),\n", "(DianeK)-[:ACTED_IN {roles:['Erica Barry']}]->(SomethingsGottaGive),\n", "(Keanu)-[:ACTED_IN {roles:['Julian Mercer']}]->(SomethingsGottaGive),\n", "(NancyM)-[:DIRECTED]->(SomethingsGottaGive),\n", "(NancyM)-[:PRODUCED]->(SomethingsGottaGive),\n", "(NancyM)-[:WROTE]->(SomethingsGottaGive)\n", "\n", "CREATE (BicentennialMan:Movie {title:'Bicentennial Man', released:1999, tagline:\"One robot's 200 year journey to become an ordinary man.\"})\n", "CREATE (ChrisC:Person {name:'Chris Columbus', born:1958})\n", "CREATE\n", "(Robin)-[:ACTED_IN {roles:['Andrew Marin']}]->(BicentennialMan),\n", "(OliverP)-[:ACTED_IN {roles:['Rupert Burns']}]->(BicentennialMan),\n", 
"(ChrisC)-[:DIRECTED]->(BicentennialMan)\n", "\n", "CREATE (CharlieWilsonsWar:Movie {title:\"Charlie Wilson's War\", released:2007, tagline:\"A stiff drink. A little mascara. A lot of nerve. Who said they couldn't bring down the Soviet empire.\"})\n", "CREATE (JuliaR:Person {name:'Julia Roberts', born:1967})\n", "CREATE\n", "(TomH)-[:ACTED_IN {roles:['Rep. Charlie Wilson']}]->(CharlieWilsonsWar),\n", "(JuliaR)-[:ACTED_IN {roles:['Joanne Herring']}]->(CharlieWilsonsWar),\n", "(PhilipH)-[:ACTED_IN {roles:['Gust Avrakotos']}]->(CharlieWilsonsWar),\n", "(MikeN)-[:DIRECTED]->(CharlieWilsonsWar)\n", "\n", "CREATE (ThePolarExpress:Movie {title:'The Polar Express', released:2004, tagline:'This Holiday Season... Believe'})\n", "CREATE\n", "(TomH)-[:ACTED_IN {roles:['Hero Boy', 'Father', 'Conductor', 'Hobo', 'Scrooge', 'Santa Claus']}]->(ThePolarExpress),\n", "(RobertZ)-[:DIRECTED]->(ThePolarExpress)\n", "\n", "CREATE (ALeagueofTheirOwn:Movie {title:'A League of Their Own', released:1992, tagline:'Once in a lifetime you get a chance to do something different.'})\n", "CREATE (Madonna:Person {name:'Madonna', born:1954})\n", "CREATE (GeenaD:Person {name:'Geena Davis', born:1956})\n", "CREATE (LoriP:Person {name:'Lori Petty', born:1963})\n", "CREATE (PennyM:Person {name:'Penny Marshall', born:1943})\n", "CREATE\n", "(TomH)-[:ACTED_IN {roles:['Jimmy Dugan']}]->(ALeagueofTheirOwn),\n", "(GeenaD)-[:ACTED_IN {roles:['Dottie Hinson']}]->(ALeagueofTheirOwn),\n", "(LoriP)-[:ACTED_IN {roles:['Kit Keller']}]->(ALeagueofTheirOwn),\n", "(RosieO)-[:ACTED_IN {roles:['Doris Murphy']}]->(ALeagueofTheirOwn),\n", "(Madonna)-[:ACTED_IN {roles:['\"All the Way\" Mae Mordabito']}]->(ALeagueofTheirOwn),\n", "(BillPax)-[:ACTED_IN {roles:['Bob Hinson']}]->(ALeagueofTheirOwn),\n", "(PennyM)-[:DIRECTED]->(ALeagueofTheirOwn)\n", "\n", "CREATE (PaulBlythe:Person {name:'Paul Blythe'})\n", "CREATE (AngelaScope:Person {name:'Angela Scope'})\n", "CREATE (JessicaThompson:Person {name:'Jessica Thompson'})\n", 
"CREATE (JamesThompson:Person {name:'James Thompson'})\n", "\n", "CREATE\n", "(JamesThompson)-[:FOLLOWS]->(JessicaThompson),\n", "(AngelaScope)-[:FOLLOWS]->(JessicaThompson),\n", "(PaulBlythe)-[:FOLLOWS]->(AngelaScope)\n", "\n", "CREATE\n", "(JessicaThompson)-[:REVIEWED {summary:'An amazing journey', rating:95}]->(CloudAtlas),\n", "(JessicaThompson)-[:REVIEWED {summary:'Silly, but fun', rating:65}]->(TheReplacements),\n", "(JamesThompson)-[:REVIEWED {summary:'The coolest football movie ever', rating:100}]->(TheReplacements),\n", "(AngelaScope)-[:REVIEWED {summary:'Pretty funny at times', rating:62}]->(TheReplacements),\n", "(JessicaThompson)-[:REVIEWED {summary:'Dark, but compelling', rating:85}]->(Unforgiven),\n", "(JessicaThompson)-[:REVIEWED {summary:\"Slapstick redeemed only by the Robin Williams and Gene Hackman's stellar performances\", rating:45}]->(TheBirdcage),\n", "(JessicaThompson)-[:REVIEWED {summary:'A solid romp', rating:68}]->(TheDaVinciCode),\n", "(JamesThompson)-[:REVIEWED {summary:'Fun, but a little far fetched', rating:65}]->(TheDaVinciCode),\n", "(JessicaThompson)-[:REVIEWED {summary:'You had me at Jerry', rating:92}]->(JerryMaguire)\n", "\n", "WITH TomH as a\n", "MATCH (a)-[:ACTED_IN]->(m)<-[:DIRECTED]-(d) RETURN a,m,d LIMIT 10;\n", "```\n", "\n", "
\n", "\n", "This will create the following graph model\n", "\n", "" ], "metadata": { "id": "AQhqv93Mj0Ss" } }, { "cell_type": "code", "source": [ "%%cypher -u $neo4j_url -us $neo4j_user -pw $neo4j_password\n", "// the following Cypher query is the same as above\n", "// and is required for running the notebook\n", "CREATE (TheMatrix:Movie {title:'The Matrix', released:1999, tagline:'Welcome to the Real World'})\n", "CREATE (Keanu:Person {name:'Keanu Reeves', born:1964})\n", "CREATE (Carrie:Person {name:'Carrie-Anne Moss', born:1967})\n", "CREATE (Laurence:Person {name:'Laurence Fishburne', born:1961})\n", "CREATE (Hugo:Person {name:'Hugo Weaving', born:1960})\n", "CREATE (LillyW:Person {name:'Lilly Wachowski', born:1967})\n", "CREATE (LanaW:Person {name:'Lana Wachowski', born:1965})\n", "CREATE (JoelS:Person {name:'Joel Silver', born:1952})\n", "CREATE\n", "(Keanu)-[:ACTED_IN {roles:['Neo']}]->(TheMatrix),\n", "(Carrie)-[:ACTED_IN {roles:['Trinity']}]->(TheMatrix),\n", "(Laurence)-[:ACTED_IN {roles:['Morpheus']}]->(TheMatrix),\n", "(Hugo)-[:ACTED_IN {roles:['Agent Smith']}]->(TheMatrix),\n", "(LillyW)-[:DIRECTED]->(TheMatrix),\n", "(LanaW)-[:DIRECTED]->(TheMatrix),\n", "(JoelS)-[:PRODUCED]->(TheMatrix)\n", "\n", "CREATE (Emil:Person {name:\"Emil Eifrem\", born:1978})\n", "CREATE (Emil)-[:ACTED_IN {roles:[\"Emil\"]}]->(TheMatrix)\n", "\n", "CREATE (TheMatrixReloaded:Movie {title:'The Matrix Reloaded', released:2003, tagline:'Free your mind'})\n", "CREATE\n", "(Keanu)-[:ACTED_IN {roles:['Neo']}]->(TheMatrixReloaded),\n", "(Carrie)-[:ACTED_IN {roles:['Trinity']}]->(TheMatrixReloaded),\n", "(Laurence)-[:ACTED_IN {roles:['Morpheus']}]->(TheMatrixReloaded),\n", "(Hugo)-[:ACTED_IN {roles:['Agent Smith']}]->(TheMatrixReloaded),\n", "(LillyW)-[:DIRECTED]->(TheMatrixReloaded),\n", "(LanaW)-[:DIRECTED]->(TheMatrixReloaded),\n", "(JoelS)-[:PRODUCED]->(TheMatrixReloaded)\n", "\n", "CREATE (TheMatrixRevolutions:Movie {title:'The Matrix Revolutions', released:2003, 
tagline:'Everything that has a beginning has an end'})\n", "CREATE\n", "(Keanu)-[:ACTED_IN {roles:['Neo']}]->(TheMatrixRevolutions),\n", "(Carrie)-[:ACTED_IN {roles:['Trinity']}]->(TheMatrixRevolutions),\n", "(Laurence)-[:ACTED_IN {roles:['Morpheus']}]->(TheMatrixRevolutions),\n", "(Hugo)-[:ACTED_IN {roles:['Agent Smith']}]->(TheMatrixRevolutions),\n", "(LillyW)-[:DIRECTED]->(TheMatrixRevolutions),\n", "(LanaW)-[:DIRECTED]->(TheMatrixRevolutions),\n", "(JoelS)-[:PRODUCED]->(TheMatrixRevolutions)\n", "\n", "CREATE (TheDevilsAdvocate:Movie {title:\"The Devil's Advocate\", released:1997, tagline:'Evil has its winning ways'})\n", "CREATE (Charlize:Person {name:'Charlize Theron', born:1975})\n", "CREATE (Al:Person {name:'Al Pacino', born:1940})\n", "CREATE (Taylor:Person {name:'Taylor Hackford', born:1944})\n", "CREATE\n", "(Keanu)-[:ACTED_IN {roles:['Kevin Lomax']}]->(TheDevilsAdvocate),\n", "(Charlize)-[:ACTED_IN {roles:['Mary Ann Lomax']}]->(TheDevilsAdvocate),\n", "(Al)-[:ACTED_IN {roles:['John Milton']}]->(TheDevilsAdvocate),\n", "(Taylor)-[:DIRECTED]->(TheDevilsAdvocate)\n", "\n", "CREATE (AFewGoodMen:Movie {title:\"A Few Good Men\", released:1992, tagline:\"In the heart of the nation's capital, in a courthouse of the U.S. government, one man will stop at nothing to keep his honor, and one will stop at nothing to find the truth.\"})\n", "CREATE (TomC:Person {name:'Tom Cruise', born:1962})\n", "CREATE (JackN:Person {name:'Jack Nicholson', born:1937})\n", "CREATE (DemiM:Person {name:'Demi Moore', born:1962})\n", "CREATE (KevinB:Person {name:'Kevin Bacon', born:1958})\n", "CREATE (KieferS:Person {name:'Kiefer Sutherland', born:1966})\n", "CREATE (NoahW:Person {name:'Noah Wyle', born:1971})\n", "CREATE (CubaG:Person {name:'Cuba Gooding Jr.', born:1968})\n", "CREATE (KevinP:Person {name:'Kevin Pollak', born:1957})\n", "CREATE (JTW:Person {name:'J.T. 
Walsh', born:1943})\n", "CREATE (JamesM:Person {name:'James Marshall', born:1967})\n", "CREATE (ChristopherG:Person {name:'Christopher Guest', born:1948})\n", "CREATE (RobR:Person {name:'Rob Reiner', born:1947})\n", "CREATE (AaronS:Person {name:'Aaron Sorkin', born:1961})\n", "CREATE\n", "(TomC)-[:ACTED_IN {roles:['Lt. Daniel Kaffee']}]->(AFewGoodMen),\n", "(JackN)-[:ACTED_IN {roles:['Col. Nathan R. Jessup']}]->(AFewGoodMen),\n", "(DemiM)-[:ACTED_IN {roles:['Lt. Cdr. JoAnne Galloway']}]->(AFewGoodMen),\n", "(KevinB)-[:ACTED_IN {roles:['Capt. Jack Ross']}]->(AFewGoodMen),\n", "(KieferS)-[:ACTED_IN {roles:['Lt. Jonathan Kendrick']}]->(AFewGoodMen),\n", "(NoahW)-[:ACTED_IN {roles:['Cpl. Jeffrey Barnes']}]->(AFewGoodMen),\n", "(CubaG)-[:ACTED_IN {roles:['Cpl. Carl Hammaker']}]->(AFewGoodMen),\n", "(KevinP)-[:ACTED_IN {roles:['Lt. Sam Weinberg']}]->(AFewGoodMen),\n", "(JTW)-[:ACTED_IN {roles:['Lt. Col. Matthew Andrew Markinson']}]->(AFewGoodMen),\n", "(JamesM)-[:ACTED_IN {roles:['Pfc. Louden Downey']}]->(AFewGoodMen),\n", "(ChristopherG)-[:ACTED_IN {roles:['Dr. 
Stone']}]->(AFewGoodMen),\n", "(AaronS)-[:ACTED_IN {roles:['Man in Bar']}]->(AFewGoodMen),\n", "(RobR)-[:DIRECTED]->(AFewGoodMen),\n", "(AaronS)-[:WROTE]->(AFewGoodMen)\n", "\n", "CREATE (TopGun:Movie {title:\"Top Gun\", released:1986, tagline:'I feel the need, the need for speed.'})\n", "CREATE (KellyM:Person {name:'Kelly McGillis', born:1957})\n", "CREATE (ValK:Person {name:'Val Kilmer', born:1959})\n", "CREATE (AnthonyE:Person {name:'Anthony Edwards', born:1962})\n", "CREATE (TomS:Person {name:'Tom Skerritt', born:1933})\n", "CREATE (MegR:Person {name:'Meg Ryan', born:1961})\n", "CREATE (TonyS:Person {name:'Tony Scott', born:1944})\n", "CREATE (JimC:Person {name:'Jim Cash', born:1941})\n", "CREATE\n", "(TomC)-[:ACTED_IN {roles:['Maverick']}]->(TopGun),\n", "(KellyM)-[:ACTED_IN {roles:['Charlie']}]->(TopGun),\n", "(ValK)-[:ACTED_IN {roles:['Iceman']}]->(TopGun),\n", "(AnthonyE)-[:ACTED_IN {roles:['Goose']}]->(TopGun),\n", "(TomS)-[:ACTED_IN {roles:['Viper']}]->(TopGun),\n", "(MegR)-[:ACTED_IN {roles:['Carole']}]->(TopGun),\n", "(TonyS)-[:DIRECTED]->(TopGun),\n", "(JimC)-[:WROTE]->(TopGun)\n", "\n", "CREATE (JerryMaguire:Movie {title:'Jerry Maguire', released:2000, tagline:'The rest of his life begins now.'})\n", "CREATE (ReneeZ:Person {name:'Renee Zellweger', born:1969})\n", "CREATE (KellyP:Person {name:'Kelly Preston', born:1962})\n", "CREATE (JerryO:Person {name:\"Jerry O'Connell\", born:1974})\n", "CREATE (JayM:Person {name:'Jay Mohr', born:1970})\n", "CREATE (BonnieH:Person {name:'Bonnie Hunt', born:1961})\n", "CREATE (ReginaK:Person {name:'Regina King', born:1971})\n", "CREATE (JonathanL:Person {name:'Jonathan Lipnicki', born:1996})\n", "CREATE (CameronC:Person {name:'Cameron Crowe', born:1957})\n", "CREATE\n", "(TomC)-[:ACTED_IN {roles:['Jerry Maguire']}]->(JerryMaguire),\n", "(CubaG)-[:ACTED_IN {roles:['Rod Tidwell']}]->(JerryMaguire),\n", "(ReneeZ)-[:ACTED_IN {roles:['Dorothy Boyd']}]->(JerryMaguire),\n", "(KellyP)-[:ACTED_IN {roles:['Avery 
Bishop']}]->(JerryMaguire),\n", "(JerryO)-[:ACTED_IN {roles:['Frank Cushman']}]->(JerryMaguire),\n", "(JayM)-[:ACTED_IN {roles:['Bob Sugar']}]->(JerryMaguire),\n", "(BonnieH)-[:ACTED_IN {roles:['Laurel Boyd']}]->(JerryMaguire),\n", "(ReginaK)-[:ACTED_IN {roles:['Marcee Tidwell']}]->(JerryMaguire),\n", "(JonathanL)-[:ACTED_IN {roles:['Ray Boyd']}]->(JerryMaguire),\n", "(CameronC)-[:DIRECTED]->(JerryMaguire),\n", "(CameronC)-[:PRODUCED]->(JerryMaguire),\n", "(CameronC)-[:WROTE]->(JerryMaguire)\n", "\n", "CREATE (StandByMe:Movie {title:\"Stand By Me\", released:1986, tagline:\"For some, it's the last real taste of innocence, and the first real taste of life. But for everyone, it's the time that memories are made of.\"})\n", "CREATE (RiverP:Person {name:'River Phoenix', born:1970})\n", "CREATE (CoreyF:Person {name:'Corey Feldman', born:1971})\n", "CREATE (WilW:Person {name:'Wil Wheaton', born:1972})\n", "CREATE (JohnC:Person {name:'John Cusack', born:1966})\n", "CREATE (MarshallB:Person {name:'Marshall Bell', born:1942})\n", "CREATE\n", "(WilW)-[:ACTED_IN {roles:['Gordie Lachance']}]->(StandByMe),\n", "(RiverP)-[:ACTED_IN {roles:['Chris Chambers']}]->(StandByMe),\n", "(JerryO)-[:ACTED_IN {roles:['Vern Tessio']}]->(StandByMe),\n", "(CoreyF)-[:ACTED_IN {roles:['Teddy Duchamp']}]->(StandByMe),\n", "(JohnC)-[:ACTED_IN {roles:['Denny Lachance']}]->(StandByMe),\n", "(KieferS)-[:ACTED_IN {roles:['Ace Merrill']}]->(StandByMe),\n", "(MarshallB)-[:ACTED_IN {roles:['Mr. Lachance']}]->(StandByMe),\n", "(RobR)-[:DIRECTED]->(StandByMe)\n", "\n", "CREATE (AsGoodAsItGets:Movie {title:'As Good as It Gets', released:1997, tagline:'A comedy from the heart that goes for the throat.'})\n", "CREATE (HelenH:Person {name:'Helen Hunt', born:1963})\n", "CREATE (GregK:Person {name:'Greg Kinnear', born:1963})\n", "CREATE (JamesB:Person {name:'James L. 
Brooks', born:1940})\n", "CREATE\n", "(JackN)-[:ACTED_IN {roles:['Melvin Udall']}]->(AsGoodAsItGets),\n", "(HelenH)-[:ACTED_IN {roles:['Carol Connelly']}]->(AsGoodAsItGets),\n", "(GregK)-[:ACTED_IN {roles:['Simon Bishop']}]->(AsGoodAsItGets),\n", "(CubaG)-[:ACTED_IN {roles:['Frank Sachs']}]->(AsGoodAsItGets),\n", "(JamesB)-[:DIRECTED]->(AsGoodAsItGets)\n", "\n", "CREATE (WhatDreamsMayCome:Movie {title:'What Dreams May Come', released:1998, tagline:'After life there is more. The end is just the beginning.'})\n", "CREATE (AnnabellaS:Person {name:'Annabella Sciorra', born:1960})\n", "CREATE (MaxS:Person {name:'Max von Sydow', born:1929})\n", "CREATE (WernerH:Person {name:'Werner Herzog', born:1942})\n", "CREATE (Robin:Person {name:'Robin Williams', born:1951})\n", "CREATE (VincentW:Person {name:'Vincent Ward', born:1956})\n", "CREATE\n", "(Robin)-[:ACTED_IN {roles:['Chris Nielsen']}]->(WhatDreamsMayCome),\n", "(CubaG)-[:ACTED_IN {roles:['Albert Lewis']}]->(WhatDreamsMayCome),\n", "(AnnabellaS)-[:ACTED_IN {roles:['Annie Collins-Nielsen']}]->(WhatDreamsMayCome),\n", "(MaxS)-[:ACTED_IN {roles:['The Tracker']}]->(WhatDreamsMayCome),\n", "(WernerH)-[:ACTED_IN {roles:['The Face']}]->(WhatDreamsMayCome),\n", "(VincentW)-[:DIRECTED]->(WhatDreamsMayCome)\n", "\n", "CREATE (SnowFallingonCedars:Movie {title:'Snow Falling on Cedars', released:1999, tagline:'First loves last. 
Forever.'})\n", "CREATE (EthanH:Person {name:'Ethan Hawke', born:1970})\n", "CREATE (RickY:Person {name:'Rick Yune', born:1971})\n", "CREATE (JamesC:Person {name:'James Cromwell', born:1940})\n", "CREATE (ScottH:Person {name:'Scott Hicks', born:1953})\n", "CREATE\n", "(EthanH)-[:ACTED_IN {roles:['Ishmael Chambers']}]->(SnowFallingonCedars),\n", "(RickY)-[:ACTED_IN {roles:['Kazuo Miyamoto']}]->(SnowFallingonCedars),\n", "(MaxS)-[:ACTED_IN {roles:['Nels Gudmundsson']}]->(SnowFallingonCedars),\n", "(JamesC)-[:ACTED_IN {roles:['Judge Fielding']}]->(SnowFallingonCedars),\n", "(ScottH)-[:DIRECTED]->(SnowFallingonCedars)\n", "\n", "CREATE (YouveGotMail:Movie {title:\"You've Got Mail\", released:1998, tagline:'At odds in life... in love on-line.'})\n", "CREATE (ParkerP:Person {name:'Parker Posey', born:1968})\n", "CREATE (DaveC:Person {name:'Dave Chappelle', born:1973})\n", "CREATE (SteveZ:Person {name:'Steve Zahn', born:1967})\n", "CREATE (TomH:Person {name:'Tom Hanks', born:1956})\n", "CREATE (NoraE:Person {name:'Nora Ephron', born:1941})\n", "CREATE\n", "(TomH)-[:ACTED_IN {roles:['Joe Fox']}]->(YouveGotMail),\n", "(MegR)-[:ACTED_IN {roles:['Kathleen Kelly']}]->(YouveGotMail),\n", "(GregK)-[:ACTED_IN {roles:['Frank Navasky']}]->(YouveGotMail),\n", "(ParkerP)-[:ACTED_IN {roles:['Patricia Eden']}]->(YouveGotMail),\n", "(DaveC)-[:ACTED_IN {roles:['Kevin Jackson']}]->(YouveGotMail),\n", "(SteveZ)-[:ACTED_IN {roles:['George Pappas']}]->(YouveGotMail),\n", "(NoraE)-[:DIRECTED]->(YouveGotMail)\n", "\n", "CREATE (SleeplessInSeattle:Movie {title:'Sleepless in Seattle', released:1993, tagline:'What if someone you never met, someone you never saw, someone you never knew was the only someone for you?'})\n", "CREATE (RitaW:Person {name:'Rita Wilson', born:1956})\n", "CREATE (BillPull:Person {name:'Bill Pullman', born:1953})\n", "CREATE (VictorG:Person {name:'Victor Garber', born:1949})\n", "CREATE (RosieO:Person {name:\"Rosie O'Donnell\", born:1962})\n", "CREATE\n", 
"(TomH)-[:ACTED_IN {roles:['Sam Baldwin']}]->(SleeplessInSeattle),\n", "(MegR)-[:ACTED_IN {roles:['Annie Reed']}]->(SleeplessInSeattle),\n", "(RitaW)-[:ACTED_IN {roles:['Suzy']}]->(SleeplessInSeattle),\n", "(BillPull)-[:ACTED_IN {roles:['Walter']}]->(SleeplessInSeattle),\n", "(VictorG)-[:ACTED_IN {roles:['Greg']}]->(SleeplessInSeattle),\n", "(RosieO)-[:ACTED_IN {roles:['Becky']}]->(SleeplessInSeattle),\n", "(NoraE)-[:DIRECTED]->(SleeplessInSeattle)\n", "\n", "CREATE (JoeVersustheVolcano:Movie {title:'Joe Versus the Volcano', released:1990, tagline:'A story of love, lava and burning desire.'})\n", "CREATE (JohnS:Person {name:'John Patrick Stanley', born:1950})\n", "CREATE (Nathan:Person {name:'Nathan Lane', born:1956})\n", "CREATE\n", "(TomH)-[:ACTED_IN {roles:['Joe Banks']}]->(JoeVersustheVolcano),\n", "(MegR)-[:ACTED_IN {roles:['DeDe', 'Angelica Graynamore', 'Patricia Graynamore']}]->(JoeVersustheVolcano),\n", "(Nathan)-[:ACTED_IN {roles:['Baw']}]->(JoeVersustheVolcano),\n", "(JohnS)-[:DIRECTED]->(JoeVersustheVolcano)\n", "\n", "CREATE (WhenHarryMetSally:Movie {title:'When Harry Met Sally', released:1998, tagline:'Can two friends sleep together and still love each other in the morning?'})\n", "CREATE (BillyC:Person {name:'Billy Crystal', born:1948})\n", "CREATE (CarrieF:Person {name:'Carrie Fisher', born:1956})\n", "CREATE (BrunoK:Person {name:'Bruno Kirby', born:1949})\n", "CREATE\n", "(BillyC)-[:ACTED_IN {roles:['Harry Burns']}]->(WhenHarryMetSally),\n", "(MegR)-[:ACTED_IN {roles:['Sally Albright']}]->(WhenHarryMetSally),\n", "(CarrieF)-[:ACTED_IN {roles:['Marie']}]->(WhenHarryMetSally),\n", "(BrunoK)-[:ACTED_IN {roles:['Jess']}]->(WhenHarryMetSally),\n", "(RobR)-[:DIRECTED]->(WhenHarryMetSally),\n", "(RobR)-[:PRODUCED]->(WhenHarryMetSally),\n", "(NoraE)-[:PRODUCED]->(WhenHarryMetSally),\n", "(NoraE)-[:WROTE]->(WhenHarryMetSally)\n", "\n", "CREATE (ThatThingYouDo:Movie {title:'That Thing You Do', released:1996, tagline:'In every life there comes a time when that 
thing you dream becomes that thing you do'})\n", "CREATE (LivT:Person {name:'Liv Tyler', born:1977})\n", "CREATE\n", "(TomH)-[:ACTED_IN {roles:['Mr. White']}]->(ThatThingYouDo),\n", "(LivT)-[:ACTED_IN {roles:['Faye Dolan']}]->(ThatThingYouDo),\n", "(Charlize)-[:ACTED_IN {roles:['Tina']}]->(ThatThingYouDo),\n", "(TomH)-[:DIRECTED]->(ThatThingYouDo)\n", "\n", "CREATE (TheReplacements:Movie {title:'The Replacements', released:2000, tagline:'Pain heals, Chicks dig scars... Glory lasts forever'})\n", "CREATE (Brooke:Person {name:'Brooke Langton', born:1970})\n", "CREATE (Gene:Person {name:'Gene Hackman', born:1930})\n", "CREATE (Orlando:Person {name:'Orlando Jones', born:1968})\n", "CREATE (Howard:Person {name:'Howard Deutch', born:1950})\n", "CREATE\n", "(Keanu)-[:ACTED_IN {roles:['Shane Falco']}]->(TheReplacements),\n", "(Brooke)-[:ACTED_IN {roles:['Annabelle Farrell']}]->(TheReplacements),\n", "(Gene)-[:ACTED_IN {roles:['Jimmy McGinty']}]->(TheReplacements),\n", "(Orlando)-[:ACTED_IN {roles:['Clifford Franklin']}]->(TheReplacements),\n", "(Howard)-[:DIRECTED]->(TheReplacements)\n", "\n", "CREATE (RescueDawn:Movie {title:'RescueDawn', released:2006, tagline:\"Based on the extraordinary true story of one man's fight for freedom\"})\n", "CREATE (ChristianB:Person {name:'Christian Bale', born:1974})\n", "CREATE (ZachG:Person {name:'Zach Grenier', born:1954})\n", "CREATE\n", "(MarshallB)-[:ACTED_IN {roles:['Admiral']}]->(RescueDawn),\n", "(ChristianB)-[:ACTED_IN {roles:['Dieter Dengler']}]->(RescueDawn),\n", "(ZachG)-[:ACTED_IN {roles:['Squad Leader']}]->(RescueDawn),\n", "(SteveZ)-[:ACTED_IN {roles:['Duane']}]->(RescueDawn),\n", "(WernerH)-[:DIRECTED]->(RescueDawn)\n", "\n", "CREATE (TheBirdcage:Movie {title:'The Birdcage', released:1996, tagline:'Come as you are'})\n", "CREATE (MikeN:Person {name:'Mike Nichols', born:1931})\n", "CREATE\n", "(Robin)-[:ACTED_IN {roles:['Armand Goldman']}]->(TheBirdcage),\n", "(Nathan)-[:ACTED_IN {roles:['Albert 
Goldman']}]->(TheBirdcage),\n", "(Gene)-[:ACTED_IN {roles:['Sen. Kevin Keeley']}]->(TheBirdcage),\n", "(MikeN)-[:DIRECTED]->(TheBirdcage)\n", "\n", "CREATE (Unforgiven:Movie {title:'Unforgiven', released:1992, tagline:\"It's a hell of a thing, killing a man\"})\n", "CREATE (RichardH:Person {name:'Richard Harris', born:1930})\n", "CREATE (ClintE:Person {name:'Clint Eastwood', born:1930})\n", "CREATE\n", "(RichardH)-[:ACTED_IN {roles:['English Bob']}]->(Unforgiven),\n", "(ClintE)-[:ACTED_IN {roles:['Bill Munny']}]->(Unforgiven),\n", "(Gene)-[:ACTED_IN {roles:['Little Bill Daggett']}]->(Unforgiven),\n", "(ClintE)-[:DIRECTED]->(Unforgiven)\n", "\n", "CREATE (JohnnyMnemonic:Movie {title:'Johnny Mnemonic', released:1995, tagline:'The hottest data on earth. In the coolest head in town'})\n", "CREATE (Takeshi:Person {name:'Takeshi Kitano', born:1947})\n", "CREATE (Dina:Person {name:'Dina Meyer', born:1968})\n", "CREATE (IceT:Person {name:'Ice-T', born:1958})\n", "CREATE (RobertL:Person {name:'Robert Longo', born:1953})\n", "CREATE\n", "(Keanu)-[:ACTED_IN {roles:['Johnny Mnemonic']}]->(JohnnyMnemonic),\n", "(Takeshi)-[:ACTED_IN {roles:['Takahashi']}]->(JohnnyMnemonic),\n", "(Dina)-[:ACTED_IN {roles:['Jane']}]->(JohnnyMnemonic),\n", "(IceT)-[:ACTED_IN {roles:['J-Bone']}]->(JohnnyMnemonic),\n", "(RobertL)-[:DIRECTED]->(JohnnyMnemonic)\n", "\n", "CREATE (CloudAtlas:Movie {title:'Cloud Atlas', released:2012, tagline:'Everything is connected'})\n", "CREATE (HalleB:Person {name:'Halle Berry', born:1966})\n", "CREATE (JimB:Person {name:'Jim Broadbent', born:1949})\n", "CREATE (TomT:Person {name:'Tom Tykwer', born:1965})\n", "CREATE (DavidMitchell:Person {name:'David Mitchell', born:1969})\n", "CREATE (StefanArndt:Person {name:'Stefan Arndt', born:1961})\n", "CREATE\n", "(TomH)-[:ACTED_IN {roles:['Zachry', 'Dr. 
Henry Goose', 'Isaac Sachs', 'Dermot Hoggins']}]->(CloudAtlas),\n", "(Hugo)-[:ACTED_IN {roles:['Bill Smoke', 'Haskell Moore', 'Tadeusz Kesselring', 'Nurse Noakes', 'Boardman Mephi', 'Old Georgie']}]->(CloudAtlas),\n", "(HalleB)-[:ACTED_IN {roles:['Luisa Rey', 'Jocasta Ayrs', 'Ovid', 'Meronym']}]->(CloudAtlas),\n", "(JimB)-[:ACTED_IN {roles:['Vyvyan Ayrs', 'Captain Molyneux', 'Timothy Cavendish']}]->(CloudAtlas),\n", "(TomT)-[:DIRECTED]->(CloudAtlas),\n", "(LillyW)-[:DIRECTED]->(CloudAtlas),\n", "(LanaW)-[:DIRECTED]->(CloudAtlas),\n", "(DavidMitchell)-[:WROTE]->(CloudAtlas),\n", "(StefanArndt)-[:PRODUCED]->(CloudAtlas)\n", "\n", "CREATE (TheDaVinciCode:Movie {title:'The Da Vinci Code', released:2006, tagline:'Break The Codes'})\n", "CREATE (IanM:Person {name:'Ian McKellen', born:1939})\n", "CREATE (AudreyT:Person {name:'Audrey Tautou', born:1976})\n", "CREATE (PaulB:Person {name:'Paul Bettany', born:1971})\n", "CREATE (RonH:Person {name:'Ron Howard', born:1954})\n", "CREATE\n", "(TomH)-[:ACTED_IN {roles:['Dr. Robert Langdon']}]->(TheDaVinciCode),\n", "(IanM)-[:ACTED_IN {roles:['Sir Leight Teabing']}]->(TheDaVinciCode),\n", "(AudreyT)-[:ACTED_IN {roles:['Sophie Neveu']}]->(TheDaVinciCode),\n", "(PaulB)-[:ACTED_IN {roles:['Silas']}]->(TheDaVinciCode),\n", "(RonH)-[:DIRECTED]->(TheDaVinciCode)\n", "\n", "CREATE (VforVendetta:Movie {title:'V for Vendetta', released:2006, tagline:'Freedom! 
Forever!'})\n", "CREATE (NatalieP:Person {name:'Natalie Portman', born:1981})\n", "CREATE (StephenR:Person {name:'Stephen Rea', born:1946})\n", "CREATE (JohnH:Person {name:'John Hurt', born:1940})\n", "CREATE (BenM:Person {name: 'Ben Miles', born:1967})\n", "CREATE\n", "(Hugo)-[:ACTED_IN {roles:['V']}]->(VforVendetta),\n", "(NatalieP)-[:ACTED_IN {roles:['Evey Hammond']}]->(VforVendetta),\n", "(StephenR)-[:ACTED_IN {roles:['Eric Finch']}]->(VforVendetta),\n", "(JohnH)-[:ACTED_IN {roles:['High Chancellor Adam Sutler']}]->(VforVendetta),\n", "(BenM)-[:ACTED_IN {roles:['Dascomb']}]->(VforVendetta),\n", "(JamesM)-[:DIRECTED]->(VforVendetta),\n", "(LillyW)-[:PRODUCED]->(VforVendetta),\n", "(LanaW)-[:PRODUCED]->(VforVendetta),\n", "(JoelS)-[:PRODUCED]->(VforVendetta),\n", "(LillyW)-[:WROTE]->(VforVendetta),\n", "(LanaW)-[:WROTE]->(VforVendetta)\n", "\n", "CREATE (SpeedRacer:Movie {title:'Speed Racer', released:2008, tagline:'Speed has no limits'})\n", "CREATE (EmileH:Person {name:'Emile Hirsch', born:1985})\n", "CREATE (JohnG:Person {name:'John Goodman', born:1960})\n", "CREATE (SusanS:Person {name:'Susan Sarandon', born:1946})\n", "CREATE (MatthewF:Person {name:'Matthew Fox', born:1966})\n", "CREATE (ChristinaR:Person {name:'Christina Ricci', born:1980})\n", "CREATE (Rain:Person {name:'Rain', born:1982})\n", "CREATE\n", "(EmileH)-[:ACTED_IN {roles:['Speed Racer']}]->(SpeedRacer),\n", "(JohnG)-[:ACTED_IN {roles:['Pops']}]->(SpeedRacer),\n", "(SusanS)-[:ACTED_IN {roles:['Mom']}]->(SpeedRacer),\n", "(MatthewF)-[:ACTED_IN {roles:['Racer X']}]->(SpeedRacer),\n", "(ChristinaR)-[:ACTED_IN {roles:['Trixie']}]->(SpeedRacer),\n", "(Rain)-[:ACTED_IN {roles:['Taejo Togokahn']}]->(SpeedRacer),\n", "(BenM)-[:ACTED_IN {roles:['Cass Jones']}]->(SpeedRacer),\n", "(LillyW)-[:DIRECTED]->(SpeedRacer),\n", "(LanaW)-[:DIRECTED]->(SpeedRacer),\n", "(LillyW)-[:WROTE]->(SpeedRacer),\n", "(LanaW)-[:WROTE]->(SpeedRacer),\n", "(JoelS)-[:PRODUCED]->(SpeedRacer)\n", "\n", "CREATE (NinjaAssassin:Movie 
{title:'Ninja Assassin', released:2009, tagline:'Prepare to enter a secret world of assassins'})\n", "CREATE (NaomieH:Person {name:'Naomie Harris'})\n", "CREATE\n", "(Rain)-[:ACTED_IN {roles:['Raizo']}]->(NinjaAssassin),\n", "(NaomieH)-[:ACTED_IN {roles:['Mika Coretti']}]->(NinjaAssassin),\n", "(RickY)-[:ACTED_IN {roles:['Takeshi']}]->(NinjaAssassin),\n", "(BenM)-[:ACTED_IN {roles:['Ryan Maslow']}]->(NinjaAssassin),\n", "(JamesM)-[:DIRECTED]->(NinjaAssassin),\n", "(LillyW)-[:PRODUCED]->(NinjaAssassin),\n", "(LanaW)-[:PRODUCED]->(NinjaAssassin),\n", "(JoelS)-[:PRODUCED]->(NinjaAssassin)\n", "\n", "CREATE (TheGreenMile:Movie {title:'The Green Mile', released:1999, tagline:\"Walk a mile you'll never forget.\"})\n", "CREATE (MichaelD:Person {name:'Michael Clarke Duncan', born:1957})\n", "CREATE (DavidM:Person {name:'David Morse', born:1953})\n", "CREATE (SamR:Person {name:'Sam Rockwell', born:1968})\n", "CREATE (GaryS:Person {name:'Gary Sinise', born:1955})\n", "CREATE (PatriciaC:Person {name:'Patricia Clarkson', born:1959})\n", "CREATE (FrankD:Person {name:'Frank Darabont', born:1959})\n", "CREATE\n", "(TomH)-[:ACTED_IN {roles:['Paul Edgecomb']}]->(TheGreenMile),\n", "(MichaelD)-[:ACTED_IN {roles:['John Coffey']}]->(TheGreenMile),\n", "(DavidM)-[:ACTED_IN {roles:['Brutus \"Brutal\" Howell']}]->(TheGreenMile),\n", "(BonnieH)-[:ACTED_IN {roles:['Jan Edgecomb']}]->(TheGreenMile),\n", "(JamesC)-[:ACTED_IN {roles:['Warden Hal Moores']}]->(TheGreenMile),\n", "(SamR)-[:ACTED_IN {roles:['\"Wild Bill\" Wharton']}]->(TheGreenMile),\n", "(GaryS)-[:ACTED_IN {roles:['Burt Hammersmith']}]->(TheGreenMile),\n", "(PatriciaC)-[:ACTED_IN {roles:['Melinda Moores']}]->(TheGreenMile),\n", "(FrankD)-[:DIRECTED]->(TheGreenMile)\n", "\n", "CREATE (FrostNixon:Movie {title:'Frost/Nixon', released:2008, tagline:'400 million people were waiting for the truth.'})\n", "CREATE (FrankL:Person {name:'Frank Langella', born:1938})\n", "CREATE (MichaelS:Person {name:'Michael Sheen', born:1969})\n", 
"CREATE (OliverP:Person {name:'Oliver Platt', born:1960})\n", "CREATE\n", "(FrankL)-[:ACTED_IN {roles:['Richard Nixon']}]->(FrostNixon),\n", "(MichaelS)-[:ACTED_IN {roles:['David Frost']}]->(FrostNixon),\n", "(KevinB)-[:ACTED_IN {roles:['Jack Brennan']}]->(FrostNixon),\n", "(OliverP)-[:ACTED_IN {roles:['Bob Zelnick']}]->(FrostNixon),\n", "(SamR)-[:ACTED_IN {roles:['James Reston, Jr.']}]->(FrostNixon),\n", "(RonH)-[:DIRECTED]->(FrostNixon)\n", "\n", "CREATE (Hoffa:Movie {title:'Hoffa', released:1992, tagline:\"He didn't want law. He wanted justice.\"})\n", "CREATE (DannyD:Person {name:'Danny DeVito', born:1944})\n", "CREATE (JohnR:Person {name:'John C. Reilly', born:1965})\n", "CREATE\n", "(JackN)-[:ACTED_IN {roles:['Hoffa']}]->(Hoffa),\n", "(DannyD)-[:ACTED_IN {roles:['Robert \"Bobby\" Ciaro']}]->(Hoffa),\n", "(JTW)-[:ACTED_IN {roles:['Frank Fitzsimmons']}]->(Hoffa),\n", "(JohnR)-[:ACTED_IN {roles:['Peter \"Pete\" Connelly']}]->(Hoffa),\n", "(DannyD)-[:DIRECTED]->(Hoffa)\n", "\n", "CREATE (Apollo13:Movie {title:'Apollo 13', released:1995, tagline:'Houston, we have a problem.'})\n", "CREATE (EdH:Person {name:'Ed Harris', born:1950})\n", "CREATE (BillPax:Person {name:'Bill Paxton', born:1955})\n", "CREATE\n", "(TomH)-[:ACTED_IN {roles:['Jim Lovell']}]->(Apollo13),\n", "(KevinB)-[:ACTED_IN {roles:['Jack Swigert']}]->(Apollo13),\n", "(EdH)-[:ACTED_IN {roles:['Gene Kranz']}]->(Apollo13),\n", "(BillPax)-[:ACTED_IN {roles:['Fred Haise']}]->(Apollo13),\n", "(GaryS)-[:ACTED_IN {roles:['Ken Mattingly']}]->(Apollo13),\n", "(RonH)-[:DIRECTED]->(Apollo13)\n", "\n", "CREATE (Twister:Movie {title:'Twister', released:1996, tagline:\"Don't Breathe. Don't Look Back.\"})\n", "CREATE (PhilipH:Person {name:'Philip Seymour Hoffman', born:1967})\n", "CREATE (JanB:Person {name:'Jan de Bont', born:1943})\n", "CREATE\n", "(BillPax)-[:ACTED_IN {roles:['Bill Harding']}]->(Twister),\n", "(HelenH)-[:ACTED_IN {roles:['Dr. 
Jo Harding']}]->(Twister),\n", "(ZachG)-[:ACTED_IN {roles:['Eddie']}]->(Twister),\n", "(PhilipH)-[:ACTED_IN {roles:['Dustin \"Dusty\" Davis']}]->(Twister),\n", "(JanB)-[:DIRECTED]->(Twister)\n", "\n", "CREATE (CastAway:Movie {title:'Cast Away', released:2000, tagline:'At the edge of the world, his journey begins.'})\n", "CREATE (RobertZ:Person {name:'Robert Zemeckis', born:1951})\n", "CREATE\n", "(TomH)-[:ACTED_IN {roles:['Chuck Noland']}]->(CastAway),\n", "(HelenH)-[:ACTED_IN {roles:['Kelly Frears']}]->(CastAway),\n", "(RobertZ)-[:DIRECTED]->(CastAway)\n", "\n", "CREATE (OneFlewOvertheCuckoosNest:Movie {title:\"One Flew Over the Cuckoo's Nest\", released:1975, tagline:\"If he's crazy, what does that make you?\"})\n", "CREATE (MilosF:Person {name:'Milos Forman', born:1932})\n", "CREATE\n", "(JackN)-[:ACTED_IN {roles:['Randle McMurphy']}]->(OneFlewOvertheCuckoosNest),\n", "(DannyD)-[:ACTED_IN {roles:['Martini']}]->(OneFlewOvertheCuckoosNest),\n", "(MilosF)-[:DIRECTED]->(OneFlewOvertheCuckoosNest)\n", "\n", "CREATE (SomethingsGottaGive:Movie {title:\"Something's Gotta Give\", released:2003})\n", "CREATE (DianeK:Person {name:'Diane Keaton', born:1946})\n", "CREATE (NancyM:Person {name:'Nancy Meyers', born:1949})\n", "CREATE\n", "(JackN)-[:ACTED_IN {roles:['Harry Sanborn']}]->(SomethingsGottaGive),\n", "(DianeK)-[:ACTED_IN {roles:['Erica Barry']}]->(SomethingsGottaGive),\n", "(Keanu)-[:ACTED_IN {roles:['Julian Mercer']}]->(SomethingsGottaGive),\n", "(NancyM)-[:DIRECTED]->(SomethingsGottaGive),\n", "(NancyM)-[:PRODUCED]->(SomethingsGottaGive),\n", "(NancyM)-[:WROTE]->(SomethingsGottaGive)\n", "\n", "CREATE (BicentennialMan:Movie {title:'Bicentennial Man', released:1999, tagline:\"One robot's 200 year journey to become an ordinary man.\"})\n", "CREATE (ChrisC:Person {name:'Chris Columbus', born:1958})\n", "CREATE\n", "(Robin)-[:ACTED_IN {roles:['Andrew Marin']}]->(BicentennialMan),\n", "(OliverP)-[:ACTED_IN {roles:['Rupert Burns']}]->(BicentennialMan),\n", 
"(ChrisC)-[:DIRECTED]->(BicentennialMan)\n", "\n", "CREATE (CharlieWilsonsWar:Movie {title:\"Charlie Wilson's War\", released:2007, tagline:\"A stiff drink. A little mascara. A lot of nerve. Who said they couldn't bring down the Soviet empire.\"})\n", "CREATE (JuliaR:Person {name:'Julia Roberts', born:1967})\n", "CREATE\n", "(TomH)-[:ACTED_IN {roles:['Rep. Charlie Wilson']}]->(CharlieWilsonsWar),\n", "(JuliaR)-[:ACTED_IN {roles:['Joanne Herring']}]->(CharlieWilsonsWar),\n", "(PhilipH)-[:ACTED_IN {roles:['Gust Avrakotos']}]->(CharlieWilsonsWar),\n", "(MikeN)-[:DIRECTED]->(CharlieWilsonsWar)\n", "\n", "CREATE (ThePolarExpress:Movie {title:'The Polar Express', released:2004, tagline:'This Holiday Season... Believe'})\n", "CREATE\n", "(TomH)-[:ACTED_IN {roles:['Hero Boy', 'Father', 'Conductor', 'Hobo', 'Scrooge', 'Santa Claus']}]->(ThePolarExpress),\n", "(RobertZ)-[:DIRECTED]->(ThePolarExpress)\n", "\n", "CREATE (ALeagueofTheirOwn:Movie {title:'A League of Their Own', released:1992, tagline:'Once in a lifetime you get a chance to do something different.'})\n", "CREATE (Madonna:Person {name:'Madonna', born:1954})\n", "CREATE (GeenaD:Person {name:'Geena Davis', born:1956})\n", "CREATE (LoriP:Person {name:'Lori Petty', born:1963})\n", "CREATE (PennyM:Person {name:'Penny Marshall', born:1943})\n", "CREATE\n", "(TomH)-[:ACTED_IN {roles:['Jimmy Dugan']}]->(ALeagueofTheirOwn),\n", "(GeenaD)-[:ACTED_IN {roles:['Dottie Hinson']}]->(ALeagueofTheirOwn),\n", "(LoriP)-[:ACTED_IN {roles:['Kit Keller']}]->(ALeagueofTheirOwn),\n", "(RosieO)-[:ACTED_IN {roles:['Doris Murphy']}]->(ALeagueofTheirOwn),\n", "(Madonna)-[:ACTED_IN {roles:['\"All the Way\" Mae Mordabito']}]->(ALeagueofTheirOwn),\n", "(BillPax)-[:ACTED_IN {roles:['Bob Hinson']}]->(ALeagueofTheirOwn),\n", "(PennyM)-[:DIRECTED]->(ALeagueofTheirOwn)\n", "\n", "CREATE (PaulBlythe:Person {name:'Paul Blythe'})\n", "CREATE (AngelaScope:Person {name:'Angela Scope'})\n", "CREATE (JessicaThompson:Person {name:'Jessica Thompson'})\n", 
"CREATE (JamesThompson:Person {name:'James Thompson'})\n", "\n", "CREATE\n", "(JamesThompson)-[:FOLLOWS]->(JessicaThompson),\n", "(AngelaScope)-[:FOLLOWS]->(JessicaThompson),\n", "(PaulBlythe)-[:FOLLOWS]->(AngelaScope)\n", "\n", "CREATE\n", "(JessicaThompson)-[:REVIEWED {summary:'An amazing journey', rating:95}]->(CloudAtlas),\n", "(JessicaThompson)-[:REVIEWED {summary:'Silly, but fun', rating:65}]->(TheReplacements),\n", "(JamesThompson)-[:REVIEWED {summary:'The coolest football movie ever', rating:100}]->(TheReplacements),\n", "(AngelaScope)-[:REVIEWED {summary:'Pretty funny at times', rating:62}]->(TheReplacements),\n", "(JessicaThompson)-[:REVIEWED {summary:'Dark, but compelling', rating:85}]->(Unforgiven),\n", "(JessicaThompson)-[:REVIEWED {summary:\"Slapstick redeemed only by the Robin Williams and Gene Hackman's stellar performances\", rating:45}]->(TheBirdcage),\n", "(JessicaThompson)-[:REVIEWED {summary:'A solid romp', rating:68}]->(TheDaVinciCode),\n", "(JamesThompson)-[:REVIEWED {summary:'Fun, but a little far fetched', rating:65}]->(TheDaVinciCode),\n", "(JessicaThompson)-[:REVIEWED {summary:'You had me at Jerry', rating:92}]->(JerryMaguire)\n", "\n", "WITH TomH as a\n", "MATCH (a)-[:ACTED_IN]->(m)<-[:DIRECTED]-(d) RETURN a,m,d LIMIT 10;" ], "metadata": { "id": "QFbjo1k24YEY" }, "execution_count": null, "outputs": [] }, { "cell_type": "markdown", "source": [ "You can query the database via **cy2py** in this simple way" ], "metadata": { "id": "peqcEHj0b35T" } }, { "cell_type": "code", "source": [ "%%cypher\n", "CALL apoc.meta.graph()" ], "metadata": { "id": "BfFOTNkncMqp" }, "execution_count": null, "outputs": [] }, { "cell_type": "markdown", "source": [ "As you can see the model is exactely how we expect!" 
], "metadata": { "id": "sGu-zpk8nY5r" } }, { "cell_type": "code", "source": [ "# this step is MANDATORY for the exercises\n", "from neo4j import GraphDatabase\n", "neo4j_driver = GraphDatabase.driver(neo4j_url, auth=(neo4j_user, neo4j_password))" ], "metadata": { "id": "_zZF1guo58cc" }, "execution_count": null, "outputs": [] }, { "cell_type": "markdown", "metadata": { "id": "c8bQe1b-7RY-" }, "source": [ "# Read data from Neo4j into Spark\n" ] }, { "cell_type": "markdown", "source": [ "The query above generates the following graph model:\n", "\n" ], "metadata": { "id": "ovoUnDmocaxK" } }, { "cell_type": "markdown", "metadata": { "id": "B1LLHYf1CsPh" }, "source": [ "## Read nodes via `labels` option" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "omdSk6ShCqfA" }, "outputs": [], "source": [ "movies_df = (spark.read\n", " .format('org.neo4j.spark.DataSource')\n", " .option('labels', ':Movie')\n", " .load())" ] }, { "cell_type": "markdown", "metadata": { "id": "RyglSgXnQcar" }, "source": [ "### Schema description" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "f9AaUINjPH4n" }, "outputs": [], "source": [ "movies_df.printSchema()" ] }, { "cell_type": "markdown", "source": [ "The `movies_df` contains a set of fields, the first two (generally) are always:\n", "\n", "* `` which represents the internal Neo4j id\n", "* `` which represents the list of labels attached to the node\n", "\n", "All other properties are taken from the node via schema resolution by using APOC or Cypher queries" ], "metadata": { "id": "jxLcYSkgZ1xf" } }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "omGjaw5QDgS-" }, "outputs": [], "source": [ "movies_df" ] }, { "cell_type": "markdown", "metadata": { "id": "7-KTRC5HD5sO" }, "source": [ "### Exercise\n", "\n", "Read all the `Person` nodes store them into a Python variable called `person_df` and then verify the results" ] }, { "cell_type": "code", "execution_count": null, 
"metadata": { "id": "ZhnsFC9KEsLp" }, "outputs": [], "source": [ "person_df = # write your spark code here" ] }, { "cell_type": "markdown", "source": [ "
\n", "\n", "Show a possible solution\n", "\n", "\n", "```python\n", "person_df = (spark.read\n", " .format('org.neo4j.spark.DataSource')\n", " .option('labels', ':Person')\n", " .load())\n", "```\n", "\n", "
\n", "\n" ], "metadata": { "id": "O4WEzidAZBh-" } }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "D5_zXyweE-QM" }, "outputs": [], "source": [ "\"\"\"\n", " This paragraph is for validating the code the you\n", " wrote above, please execute it after you\n", " created the person_df\n", "\"\"\"\n", "\n", "assert person_df.count() == 133\n", "assert person_df.schema.fieldNames() == ['', '', 'name', 'born']\n", "assert person_df.collect()[0][\"\"] == ['Person']\n", "print(\"All assertion are successfuly satisfied. Congrats you created your first DataFrame\")" ] }, { "cell_type": "markdown", "metadata": { "id": "m1hgGMLCRoZx" }, "source": [ "## Read relationships via `relationship` option" ] }, { "cell_type": "markdown", "metadata": { "id": "HgPockV0I5Q3" }, "source": [ "There are two way to transform relationships into DataFrame\n", "\n", "* having all the node and relationship data flattened into the DataFrame\n", "* having all the node properties in maps and the relationship data as columns" ] }, { "cell_type": "markdown", "metadata": { "id": "m0DBqZLtKtvX" }, "source": [ "### DataFrame with flattened data" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "796cuMwXR2zi" }, "outputs": [], "source": [ "actedin_df = (spark.read\n", " .format('org.neo4j.spark.DataSource')\n", " .option('relationship', 'ACTED_IN')\n", " .option('relationship.source.labels', ':Person')\n", " .option('relationship.target.labels', ':Movie')\n", " .load())" ] }, { "cell_type": "markdown", "source": [ "### Schema description" ], "metadata": { "id": "yzyviI5vXO4K" } }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "5uDWZqoySNGc" }, "outputs": [], "source": [ "actedin_df.printSchema()" ] }, { "cell_type": "markdown", "source": [ "The `movies_df` contains a set of fields, the first two (generally) are always:\n", "\n", "* `` which represents the internal Neo4j relationship id\n", "* `` which represents the relationship type\n", 
"* `` which represents the internal Neo4j node id\n", "* `` which represents the list of labels attached to the node\n", "* `rel.*` which represents the properties attached to the relationship\n", "* `source/target.*` which represents the properties attached to the node\n", "\n", "All other properties are taken from the node via schema resolution by using APOC or Cypher queries" ], "metadata": { "id": "2dB9DL7KZxrX" } }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "VPHDTL-IUX2X" }, "outputs": [], "source": [ "actedin_df" ] }, { "cell_type": "markdown", "metadata": { "id": "RoPVDptGKy_m" }, "source": [ "### DataFrame with nodes as map" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "8VxgDlBXIt_h" }, "outputs": [], "source": [ "actedin_map_df = (spark.read\n", " .format('org.neo4j.spark.DataSource')\n", " .option('relationship.nodes.map', True)\n", " .option('relationship', 'ACTED_IN')\n", " .option('relationship.source.labels', ':Person')\n", " .option('relationship.target.labels', ':Movie')\n", " .load())" ] }, { "cell_type": "markdown", "source": [ "### Schema description" ], "metadata": { "id": "nbGAYcd7YMOp" } }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "gF5-m4BbI2Ib" }, "outputs": [], "source": [ "actedin_map_df.printSchema()" ] }, { "cell_type": "markdown", "source": [ "The `movies_df` contains a set of fields, the first two (generally) are always:\n", "\n", "* `` which represents the internal Neo4j relationship id\n", "* `` which represents the relationship type\n", "* `` which represents a map with node values\n", "* `rel.*` which represents the properties attached to the relationship\n", "\n", "All other properties are taken from the node via schema resolution by using APOC or Cypher queries" ], "metadata": { "id": "Zuu42SpfZ502" } }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "UUYUwE3CLA_r" }, "outputs": [], "source": [ "actedin_map_df" ] }, { "cell_type": 
"code", "source": [ "actedin_map_df.collect()[0][\"\"]" ], "metadata": { "id": "hHPq7neyYDSx" }, "execution_count": null, "outputs": [] }, { "cell_type": "markdown", "source": [ "### Exercise\n", "\n", "Read all the `DIRECTED` relationships" ], "metadata": { "id": "Viop-9_thCbF" } }, { "cell_type": "code", "source": [ "directed_df = # write your spark code here" ], "metadata": { "id": "j0tTsk59hhLh" }, "execution_count": null, "outputs": [] }, { "cell_type": "markdown", "source": [ "
\n", "\n", "Show a possible solution\n", "\n", "\n", "```python\n", "directed_df = (spark.read\n", " .format('org.neo4j.spark.DataSource')\n", " .option('relationship', 'DIRECTED')\n", " .option('relationship.source.labels', ':Person')\n", " .option('relationship.target.labels', ':Movie')\n", " .load())\n", "```\n", "\n", "
" ], "metadata": { "id": "VQYyYSMpj2lf" } }, { "cell_type": "code", "source": [ "\"\"\"\n", " This paragraph is for validating the code the you\n", " wrote above, please execute it after you\n", " created the directed_df\n", "\"\"\"\n", "\n", "assert directed_df.count() == 44\n", "assert directed_df.schema.fieldNames() == ['',\n", " '',\n", " '',\n", " '',\n", " 'source.name',\n", " 'source.born',\n", " '',\n", " '',\n", " 'target.title',\n", " 'target.tagline',\n", " 'target.released']\n", "assert directed_df.collect()[0][\"\"] == 'DIRECTED'\n", "print(\"All assertion are successfuly satisfied. Congrats you created your first relationship DataFrame\")" ], "metadata": { "id": "HRwaJ8PvhudP" }, "execution_count": null, "outputs": [] }, { "cell_type": "markdown", "source": [ "## Read arbitrary data via Cypher query" ], "metadata": { "id": "hCpzW904dS2r" } }, { "cell_type": "code", "source": [ "cypher_df = (spark.read\n", " .format('org.neo4j.spark.DataSource')\n", " .option('query', '''\n", " // Extend Tom Hanks co-actors, to find co-co-actors who haven't worked with Tom Hanks\n", " MATCH (tom:Person {name:\"Tom Hanks\"})-[:ACTED_IN]->(m)<-[:ACTED_IN]-(coActors),\n", " (coActors)-[:ACTED_IN]->(m2)<-[:ACTED_IN]-(cocoActors)\n", " WHERE NOT (tom)-[:ACTED_IN]->()<-[:ACTED_IN]-(cocoActors)\n", " AND tom <> cocoActors\n", " RETURN cocoActors.name AS Recommended, count(*) AS Strength\n", " ORDER BY Strength DESC\n", " ''')\n", " .load())" ], "metadata": { "id": "hplBy0b_dhnb" }, "execution_count": null, "outputs": [] }, { "cell_type": "markdown", "source": [ "### Schema description" ], "metadata": { "id": "tRZPA6xWeSCT" } }, { "cell_type": "code", "source": [ "cypher_df.printSchema()" ], "metadata": { "id": "hU-JfgNNeL5f" }, "execution_count": null, "outputs": [] }, { "cell_type": "code", "source": [ "cypher_df" ], "metadata": { "id": "8IcUQsileXQ7" }, "execution_count": null, "outputs": [] }, { "cell_type": "code", "source": [ "%%cypher\n", "// Just for debugging purposes 
let's check the same query directly from the database\n", "MATCH (tom:Person {name:\"Tom Hanks\"})-[:ACTED_IN]->(m)<-[:ACTED_IN]-(coActors),\n", " (coActors)-[:ACTED_IN]->(m2)<-[:ACTED_IN]-(cocoActors)\n", "WHERE NOT (tom)-[:ACTED_IN]->()<-[:ACTED_IN]-(cocoActors)\n", " AND tom <> cocoActors\n", "RETURN cocoActors.name AS Recommended, count(*) AS Strength\n", "ORDER BY Strength DESC\n", "LIMIT 20" ], "metadata": { "id": "CWtNAeoN4O6S" }, "execution_count": null, "outputs": [] }, { "cell_type": "markdown", "source": [ "### Exercise\n", "\n", "Return all the actors that have also directed a movie.\n", "\n", "The returned DataFrame must have 3 columns:\n", "\n", "* `name` the actor name\n", "* `acted_in` a list of unique films (title) where he acted in\n", "* `directed` a list of unique films (title) where he was a director" ], "metadata": { "id": "Hyq4KsKQegdE" } }, { "cell_type": "code", "source": [ "your_cypher_df = # write your spark code here" ], "metadata": { "id": "0h1FFuYxej2f" }, "execution_count": null, "outputs": [] }, { "cell_type": "markdown", "source": [ "
\n", "\n", "Show a possible solution\n", "\n", "\n", "```python\n", "your_cypher_df = (spark.read\n", " .format('org.neo4j.spark.DataSource')\n", " .option('query', '''\n", " MATCH (p:Person)\n", " MATCH (p)-[:ACTED_IN]->(m:Movie)\n", " MATCH (p)-[:DIRECTED]->(m1:Movie)\n", " RETURN p.name AS name, collect(m.title) AS acted_in, collect(m1.title) AS directed\n", " ''')\n", " .load())\n", "```\n", "\n", "
" ], "metadata": { "id": "CQncCY52lCxv" } }, { "cell_type": "code", "source": [ "\"\"\"\n", " This paragraph is for validating the code the you\n", " wrote above, please execute it after you\n", " created the your_cypher_df\n", "\"\"\"\n", "\n", "assert your_cypher_df.count() == 5\n", "assert your_cypher_df.schema.fieldNames() == ['name', 'acted_in', 'directed']\n", "your_cypher_df_collect = your_cypher_df.collect()\n", "assert frozenset(map(lambda row: row['name'], your_cypher_df_collect)) == frozenset(['Clint Eastwood',\n", " 'Danny DeVito',\n", " 'James Marshall',\n", " 'Werner Herzog',\n", " 'Tom Hanks'])\n", "assert frozenset(map(lambda row: frozenset(row['acted_in']), your_cypher_df_collect)) == set([\n", " frozenset([\"Apollo 13\", \"You've Got Mail\", \"A League of Their Own\", \"Joe Versus the Volcano\", \"That Thing You Do\", \"The Da Vinci Code\", \"Cloud Atlas\", \"Cast Away\", \"The Green Mile\", \"Sleepless in Seattle\", \"The Polar Express\", \"Charlie Wilson's War\"]),\n", " frozenset([\"What Dreams May Come\"]),\n", " frozenset([\"Unforgiven\"]),\n", " frozenset([\"A Few Good Men\"]),\n", " frozenset([\"Hoffa\", \"One Flew Over the Cuckoo's Nest\"])\n", " ])\n", "assert frozenset(map(lambda row: frozenset(row['directed']), your_cypher_df_collect)) == set([\n", " frozenset([\"That Thing You Do\"]),\n", " frozenset([\"RescueDawn\"]),\n", " frozenset([\"Unforgiven\"]),\n", " frozenset([\"V for Vendetta\", \"Ninja Assassin\"]),\n", " frozenset([\"Hoffa\"])\n", " ])\n", "print(\"All assertion are successfuly satisfied. 
Congrats you created your first cypher dataframe\")" ], "metadata": { "id": "xG_7Wy-_go5V" }, "execution_count": null, "outputs": [] }, { "cell_type": "markdown", "metadata": { "id": "hFpA11aK8ADf" }, "source": [ "# Write data from Spark to Neo4j" ] }, { "cell_type": "markdown", "source": [ "## The graph model\n", "\n", "Our goal is to create this simple graph model\n", "\n", "" ], "metadata": { "id": "Mx84Qi1PcHF_" } }, { "cell_type": "markdown", "metadata": { "id": "Trt-L_9pMQf1" }, "source": [ "### Download The Dataset" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "h4o07NpuJmaG" }, "outputs": [], "source": [ "!wget -q https://s3.amazonaws.com/dev.assets.neo4j.com/wp-content/uploads/desktop-csv-import.zip" ] }, { "cell_type": "markdown", "source": [ "The zip is composed of three files:\n", "* products.csv: describes the products and has three columns (and no header)\n", "* orders.csv: has three columns (with the header) and describe the order\n", "* order-details.csv: is the \"join\" table between orders and products; it has three columns with header" ], "metadata": { "id": "KKfl_ZyhYYWj" } }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "nduIG7H_J0-A" }, "outputs": [], "source": [ "!unzip desktop-csv-import.zip" ] }, { "cell_type": "markdown", "metadata": { "id": "w5kaTPoEQNkT" }, "source": [ "### Explore the Dataset" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "L0TZgi_E1gAv" }, "outputs": [], "source": [ "products_df = (spark.read\n", " .format('csv')\n", " .option('inferSchema', True)\n", " .option('path', '/content/desktop-csv-import/products.csv')\n", " .load())" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "0tptDEUn2WO6" }, "outputs": [], "source": [ "products_df.printSchema()" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "ZSW40PxjKvgf" }, "outputs": [], "source": [ "products_df" ] }, { "cell_type": "markdown", "source": [ 
"As you can see in the schema, colums have no name, just a generic `_c` prefix concatenated with an index.\n", "The three columns describe:\n", "* `_c0` is the `id` of the product\n", "* `_c1` is the `name`\n", "* `_c2` is the `price`\n", "\n", "Let's rename these columns!" ], "metadata": { "id": "VrKERh4uXu8J" } }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "7xabRYlXQZl4" }, "outputs": [], "source": [ "products_df = (products_df.withColumnRenamed('_c0', 'id')\n", " .withColumnRenamed('_c1', 'name')\n", " .withColumnRenamed('_c2', 'price'))" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "w08HTcIeQuNn" }, "outputs": [], "source": [ "products_df.printSchema()" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "XCkyYIyzQv0X" }, "outputs": [], "source": [ "products_df" ] }, { "cell_type": "markdown", "source": [ "## Write nodes via `label` option" ], "metadata": { "id": "XL9oBe0-m680" } }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "Oy4eoeAMRWxc" }, "outputs": [], "source": [ "(products_df.write\n", " .format('org.neo4j.spark.DataSource')\n", " .mode('append')\n", " .option('labels', ':Product')\n", " .save())" ] }, { "cell_type": "markdown", "source": [ "Let's check if the nodes are in the database!" 
], "metadata": { "id": "MLxewMeFsQj3" } }, { "cell_type": "code", "source": [ "%%cypher\n", "MATCH (n:Product)\n", "RETURN n\n", "LIMIT 10" ], "metadata": { "id": "2yoBPvmmsUqt" }, "execution_count": null, "outputs": [] }, { "cell_type": "markdown", "source": [ "Now just to be sure that we loaded all the nodes into Neo4j we'll count the dataframe and the nodes inside the database" ], "metadata": { "id": "10PLvdmZ0tZT" } }, { "cell_type": "code", "source": [ "products_df.count()" ], "metadata": { "id": "AmwadsK702Sl" }, "execution_count": null, "outputs": [] }, { "cell_type": "code", "source": [ "%%cypher\n", "MATCH (n:Product)\n", "RETURN count(n)" ], "metadata": { "id": "J8e9jZPG06DL" }, "execution_count": null, "outputs": [] }, { "cell_type": "markdown", "source": [ "If the two counts are equal, all the data has been properly imported." ], "metadata": { "id": "T8XAaw3K1Mga" } }, { "cell_type": "markdown", "source": [ "### Create Constraints\n", "\n", "Oh but wait, we forgot to create constraints!!! 
if we go into the Neo4j browser and execute the following query:\n", "\n", "```cypher\n", "show constraints\n", "```\n", "\n", "We should get the constraints of the movie database, but not one for `Product`.\n", "\n", "So please create the constraints for the node `Product`:\n", "\n", "```cypher\n", "CREATE CONSTRAINT product_id FOR (p:Product) REQUIRE p.id IS UNIQUE;\n", "```\n", "\n", "But if you want, you can also delegate the Spark connector to perform optimizations pre-processing by using the option `schema.optimization.type` which can assume two values:\n", "\n", "* `INDEX`: it creates only indexes on provided nodes.\n", "* `NODE_CONSTRAINTS`: it creates only unique constraints on provided nodes.\n", "\n", "So let's create the `Order` node by letting the connector create the constraints for you
"orders_df.printSchema()" ], "metadata": { "id": "7UGwHtuJwFU4" }, "execution_count": null, "outputs": [] }, { "cell_type": "code", "source": [ "orders_df" ], "metadata": { "id": "uGmUP5ZMwJkm" }, "execution_count": null, "outputs": [] }, { "cell_type": "code", "source": [ "# we cast orderDate to timestamp in order to have it converted properly into Neo4j\n", "orders_df = orders_df.selectExpr('orderID AS id', 'CAST(orderDate AS TIMESTAMP) AS date', 'shipCountry')" ], "metadata": { "id": "InwYglcUwXNy" }, "execution_count": null, "outputs": [] }, { "cell_type": "code", "source": [ "orders_df.printSchema()" ], "metadata": { "id": "kSJmW81cw_ES" }, "execution_count": null, "outputs": [] }, { "cell_type": "code", "source": [ "orders_df" ], "metadata": { "id": "7VN5RAizxSr2" }, "execution_count": null, "outputs": [] }, { "cell_type": "code", "source": [ "(orders_df.write\n", " .format('org.neo4j.spark.DataSource')\n", " .mode('overwrite')\n", " .option('labels', ':Order')\n", " .option('schema.optimization.type', 'NODE_CONSTRAINTS')\n", " # this is necessary in order to specify what is the constraint field\n", " .option('node.keys', 'id')\n", " .save())" ], "metadata": { "id": "0rwKB9V-xTzu" }, "execution_count": null, "outputs": [] }, { "cell_type": "markdown", "source": [ "Now let's check if the connector has created the constraint for us" ], "metadata": { "id": "A_T2beEJx3Ho" } }, { "cell_type": "code", "source": [ "%%cypher\n", "SHOW CONSTRAINTS" ], "metadata": { "id": "hDxdKyFY1sDT" }, "execution_count": null, "outputs": [] }, { "cell_type": "markdown", "source": [ "As you can see the we have the constraint `spark_NODE_CONSTRAINTS_Order_id` that has been create by the Spark connector itself.\n", "\n", "Now just because we're courious let's check if the data has been propertly loaded.\n", "\n", "The first thing to check is if the count of the Dataframe and the nodes in Neo4j matches." 
], "metadata": { "id": "H5Bne-Mq1vsO" } }, { "cell_type": "code", "source": [ "orders_df.count()" ], "metadata": { "id": "mANrH-Zt2ShO" }, "execution_count": null, "outputs": [] }, { "cell_type": "code", "source": [ "%%cypher\n", "MATCH (o:Order)\n", "RETURN count(o)" ], "metadata": { "id": "vzCCVYAK2V0X" }, "execution_count": null, "outputs": [] }, { "cell_type": "markdown", "source": [ "Now we want to check if the data has been loaded with the proper data type, in particular we created a new column `date` by casting `orderDate` to `TIMESTAMP`." ], "metadata": { "id": "X8BBxCQg2dFM" } }, { "cell_type": "code", "source": [ "%%cypher\n", "MATCH (o:Order)\n", "RETURN apoc.meta.cypher.type(o.date), count(o)" ], "metadata": { "id": "VN7XMI192xAP" }, "execution_count": null, "outputs": [] }, { "cell_type": "markdown", "source": [ "So all the `date` values have the same type." ], "metadata": { "id": "PpXyKtJF3Lk4" } }, { "cell_type": "markdown", "source": [ "### Exercise\n", "\n", "Given the `football_teams_df` and `football_player_df` below please:\n", "* for `football_teams_df` insert it as nodes with label `:FootballTeam` in Neo4j.\n", "* for `football_player_df` insert it as nodes with label `:FootballPlayer` in Neo4j.\n", "\n", "Create for both of them constraints via the schema optimization feature:\n", "* for `football_teams_df` the key must be the property `id`\n", "* for `football_player_df` the key must be the property `name`" ], "metadata": { "id": "Zi4Dl2LqmVmN" } }, { "cell_type": "code", "source": [ "football_teams_df = spark.createDataFrame([{'id': 1, 'name': 'AC Milan'}, {'id': 2, 'name': 'FC Internazionale'}])\n", "football_player_df = spark.createDataFrame([\n", " {'name': 'Zlatan Ibrahimovic'},\n", " {'name': 'Sandro Tonali'},\n", " {'name': 'Nicolò Barella'},\n", " {'name': 'Marcelo Brozovic'}])" ], "metadata": { "id": "21tFDtgAmVON" }, "execution_count": null, "outputs": [] }, { "cell_type": "code", "source": [ "# write your spark code that persist 
football_teams_df and football_player_df here" ], "metadata": { "id": "utpvz-fI6blD" }, "execution_count": null, "outputs": [] }, { "cell_type": "markdown", "source": [ "
\n", "\n", "Show a possible solution\n", "\n", "\n", "```python\n", "# write the teams\n", "(football_teams_df.write\n", " .format('org.neo4j.spark.DataSource')\n", " .mode('overwrite')\n", " .option('labels', ':FootballTeam')\n", " .option('schema.optimization.type', 'NODE_CONSTRAINTS')\n", " .option('node.keys', 'id')\n", " .save())\n", "# write the players\n", "(football_player_df.write\n", " .format('org.neo4j.spark.DataSource')\n", " .mode('overwrite')\n", " .option('labels', ':FootballPlayer')\n", " .option('schema.optimization.type', 'NODE_CONSTRAINTS')\n", " .option('node.keys', 'name')\n", " .save())\n", "```\n", "\n", "
" ], "metadata": { "id": "xeVxPe7PmEy8" } }, { "cell_type": "code", "source": [ "\"\"\"\n", " This paragraph is for validating the code the you\n", " wrote above, please execute it after you\n", " persisted football_teams_df and\n", " football_player_df in Neo4j as nodes\n", "\"\"\"\n", "\n", "with neo4j_driver.session() as session:\n", " # count football players\n", " football_players = session.read_transaction(lambda tx: (tx.run('''\n", " MATCH (p:FootballPlayer)\n", " WHERE p.name IN ['Zlatan Ibrahimovic', 'Sandro Tonali',\n", " 'Nicolò Barella', 'Marcelo Brozovic']\n", " RETURN count(p) AS count\n", " ''').single()['count']))\n", " assert football_players == 4\n", "\n", " # count football teams\n", " football_teams = session.read_transaction(lambda tx: (tx.run('''\n", " MATCH (p:FootballTeam)\n", " WHERE p.name IN ['AC Milan', 'FC Internazionale']\n", " RETURN count(p) AS count\n", " ''').single()['count']))\n", " assert football_teams == 2\n", "\n", " # count constraints\n", " football_constraints = session.read_transaction(lambda tx: (tx.run('''\n", " SHOW CONSTRAINTS YIELD name\n", " WHERE name IN ['spark_NODE_CONSTRAINTS_FootballPlayer_name', 'spark_NODE_CONSTRAINTS_FootballTeam_id']\n", " RETURN count(*) AS count\n", " ''').single()['count']))\n", " assert football_constraints == 2\n", "\n", "print(\"All assertion are successfuly satisfied. 
Congrats you saved your first Node DataFrame into Neo4j!\")" ], "metadata": { "id": "a5zQCEyK6h5f" }, "execution_count": null, "outputs": [] }, { "cell_type": "markdown", "source": [ "## Write relationships via `relationship` option" ], "metadata": { "id": "QltYkhIuy2Kc" } }, { "cell_type": "code", "source": [ "order_details_df = (spark.read\n", " .format('csv')\n", " .option('inferSchema', True)\n", " .option('header', True)\n", " .option('path', '/content/desktop-csv-import/order-details.csv')\n", " .load())" ], "metadata": { "id": "y3U_X0b7x2UN" }, "execution_count": null, "outputs": [] }, { "cell_type": "code", "source": [ "order_details_df.printSchema()" ], "metadata": { "id": "zFqKW-j-zRyk" }, "execution_count": null, "outputs": [] }, { "cell_type": "code", "source": [ "order_details_df" ], "metadata": { "id": "5s5f3W984Jc0" }, "execution_count": null, "outputs": [] }, { "cell_type": "markdown", "source": [ "Please remember that this is the pattern that we want to ingest:\n", "\n", "\n", "" ], "metadata": { "id": "aZWSOtNK2qEw" } }, { "cell_type": "code", "source": [ "(order_details_df.write\n", " .format('org.neo4j.spark.DataSource')\n", " .mode('overwrite')\n", " .option('relationship', 'CONTAINS')\n", " .option('relationship.save.strategy', 'keys')\n", " .option('relationship.source.labels', ':Product')\n", " .option('relationship.source.save.mode', 'Match')\n", " .option('relationship.source.node.keys', 'productID:id')\n", " .option('relationship.target.labels', ':Order')\n", " .option('relationship.target.save.mode', 'Match')\n", " .option('relationship.target.node.keys', 'orderID:id')\n", " .option('relationship.properties', 'quantity:quantityOrdered')\n", " .save())" ], "metadata": { "id": "rFo3KWA90rmZ" }, "execution_count": null, "outputs": [] }, { "cell_type": "markdown", "source": [ "Now let's check the count for both Dataframe and relationships in Neo4j" ], "metadata": { "id": "1PaC36iZ3bNf" } }, { "cell_type": "code", "source": [ 
"order_details_df.count()" ], "metadata": { "id": "OUZ5FYHP3qvj" }, "execution_count": null, "outputs": [] }, { "cell_type": "code", "source": [ "%%cypher\n", "MATCH (p:Product)-[r:CONTAINS]->(o:Order)\n", "RETURN count(r)" ], "metadata": { "id": "ex8o2pSo34cI" }, "execution_count": null, "outputs": [] }, { "cell_type": "markdown", "source": [ "### Exercise\n", "\n", "Given the `team_player_df` create a relationship between `:FootballPlayer` and `:FootballTeam` of type `PLAYS_FOR`:\n", "\n", "```cypher\n", "(:FootballPlayer)-[:PLAYS_FOR]->(:FootballTeam)\n", "```" ], "metadata": { "id": "lTBo369687lC" } }, { "cell_type": "code", "source": [ "team_player_df = spark.createDataFrame([\n", " {'id': 1, 'football_player': 'Zlatan Ibrahimovic'},\n", " {'id': 1, 'football_player': 'Sandro Tonali'},\n", " {'id': 2, 'football_player': 'Nicolò Barella'},\n", " {'id': 2, 'football_player': 'Marcelo Brozovic'}])" ], "metadata": { "id": "gi9kB0l49f8H" }, "execution_count": null, "outputs": [] }, { "cell_type": "code", "source": [ "# write your spark code that persist team_player_df here" ], "metadata": { "id": "i7ZjASDx_Kiy" }, "execution_count": null, "outputs": [] }, { "cell_type": "markdown", "source": [ "
\n", "\n", "Show a possible solution\n", "\n", "\n", "```python\n", "(team_player_df.write\n", " .format('org.neo4j.spark.DataSource')\n", " .mode('overwrite')\n", " .option('relationship', 'PLAYS_FOR')\n", " .option('relationship.save.strategy', 'keys')\n", " .option('relationship.source.labels', ':FootballPlayer')\n", " .option('relationship.source.save.mode', 'Match')\n", " .option('relationship.source.node.keys', 'football_player:name')\n", " .option('relationship.target.labels', ':FootballTeam')\n", " .option('relationship.target.save.mode', 'Match')\n", " .option('relationship.target.node.keys', 'id')\n", " .save())\n", "```\n", "\n", "
" ], "metadata": { "id": "oNRJevUSm0Wi" } }, { "cell_type": "code", "source": [ "\"\"\"\n", " This paragraph is for validating the code the you\n", " wrote above, please execute it after you\n", " persisted team_player_df as relationships\n", "\"\"\"\n", "\n", "with neo4j_driver.session() as session:\n", " # count relationships\n", " def count_relationships(tx):\n", " result = tx.run('''\n", " MATCH (p:FootballPlayer)-[:PLAYS_FOR]->(t:FootballTeam)\n", " RETURN t.name AS team, collect(p.name) AS players\n", " ORDER by team\n", " ''')\n", " return [{'team': record['team'], 'players': set(record['players'])} for record in result]\n", "\n", " actual = session.read_transaction(count_relationships)\n", " expected = [\n", " {'team': 'AC Milan', 'players': frozenset(['Zlatan Ibrahimovic', 'Sandro Tonali'])},\n", " {'team': 'FC Internazionale', 'players': frozenset(['Nicolò Barella', 'Marcelo Brozovic'])}\n", " ]\n", " assert actual == expected\n", "\n", "print(\"All assertion are successfuly satisfied. Congrats you saved your first Relationship DataFrame into Neo4j!\")" ], "metadata": { "id": "LDYbmoUx_Owb" }, "execution_count": null, "outputs": [] }, { "cell_type": "markdown", "source": [ "## Write custom graphs via Cypher Query\n", "\n", "Now let's consider that two actors created an order and bought several products, and we want to add information in our database." 
], "metadata": { "id": "vBq9NCWlZkHw" } }, { "cell_type": "code", "source": [ "actor_orders = [\n", " {'actor_name': 'Cuba Gooding Jr.', 'order_id': 1, 'products': [11, 42, 72], 'quantities': [1, 2, 3], 'order_date': '2022-06-07 00:00:00'},\n", " {'actor_name': 'Tom Hanks', 'order_id': 2, 'products': [24, 55, 75], 'quantities': [3, 2, 1], 'order_date': '2022-06-06 00:00:00'}\n", "]\n", "\n", "actor_orders_df = spark.createDataFrame(actor_orders)" ], "metadata": { "id": "3y_yEOouaHxe" }, "execution_count": null, "outputs": [] }, { "cell_type": "code", "source": [ "actor_orders_df.printSchema()" ], "metadata": { "id": "hYo7nlgvbfdm" }, "execution_count": null, "outputs": [] }, { "cell_type": "code", "source": [ "actor_orders_df" ], "metadata": { "id": "fOkW4w2lbhZP" }, "execution_count": null, "outputs": [] }, { "cell_type": "markdown", "source": [ "In this case please go into Neo4j and create the following constraint:\n", "\n", "```cypher\n", "CREATE CONSTRAINT person_name FOR (p:Person) REQUIRE p.name is UNIQUE;\n", "```" ], "metadata": { "id": "q0wAIB5l7qp9" } }, { "cell_type": "code", "source": [ "%%cypher\n", "// if you didn't before create the constraint on Person.name\n", "CREATE CONSTRAINT person_name IF NOT EXISTS FOR (p:Person) REQUIRE p.name is UNIQUE;" ], "metadata": { "id": "-_Lxm_zV4qw2" }, "execution_count": null, "outputs": [] }, { "cell_type": "code", "source": [ "(actor_orders_df.write\n", " .format('org.neo4j.spark.DataSource')\n", " .mode('overwrite')\n", " .option('query', '''\n", " MATCH (person:Person {name: event.actor_name})\n", " MERGE (order:Order {id: event.order_id, date: datetime(replace(event.order_date, ' ', 'T'))})\n", " MERGE (person)-[:CREATED]->(order)\n", " WITH event, order\n", " UNWIND range(0, size(event.products) - 1) AS index\n", " MATCH (product:Product {id: event.products[index]})\n", " MERGE (product)-[:CONTAINS{quantityOrdered: event.quantities[index]}]->(order)\n", " ''')\n", " .save())" ], "metadata": { "id": 
"cm3RQyLbbjue" }, "execution_count": null, "outputs": [] }, { "cell_type": "markdown", "source": [ "What we expect now is that for the two actors there are two orders one per each, then each order contains three products." ], "metadata": { "id": "61LUSN4F6pQq" } }, { "cell_type": "code", "source": [ "%%cypher\n", "MATCH (a:Person)-[:CREATED]->(o:Order)<-[c:CONTAINS]-(p:Product)\n", "WHERE a.name IN ['Cuba Gooding Jr.', 'Tom Hanks']\n", "RETURN a.name, o.id, o.date, p.name, c.quantityOrdered" ], "metadata": { "id": "8fVMQiTf61mN" }, "execution_count": null, "outputs": [] }, { "cell_type": "markdown", "source": [ "### Exercise\n", "\n", "Given `neo4j_resources_df` build a small Knowledge Graph in Neo4j with the following structure:\n", "\n", "```cypher\n", "(:Author{name})-[:CREATED]->(:Resource{name})-[:HAS_TAG]->(:Tag{name})\n", "```" ], "metadata": { "id": "pn7IM8me9R3I" } }, { "cell_type": "code", "source": [ "neo4j_resources_df = spark.createDataFrame([\n", " {'author': 'LARUS Business Automation', 'resource': 'Galileo.XAI', 'tags': ['Graph Machine Learning', 'Neo4j', 'Explainable AI', 'Artificial Intelligence']},\n", " {'author': 'Neo4j', 'resource': 'Graph Data Science Library', 'tags': ['Graph Machine Learning', 'Algorithms']},\n", " {'author': 'Michael Hunger', 'resource': 'APOC', 'tags': ['Graph Data Integration', 'Graph Algorithms']}\n", "])" ], "metadata": { "id": "_wmoLl8d9RVz" }, "execution_count": null, "outputs": [] }, { "cell_type": "code", "source": [ "neo4j_resources_df" ], "metadata": { "id": "NkG1jCynLXgJ" }, "execution_count": null, "outputs": [] }, { "cell_type": "code", "source": [ "# write your spark code that persist neo4j_resources_df here" ], "metadata": { "id": "8EQmY-qhsbi1" }, "execution_count": null, "outputs": [] }, { "cell_type": "markdown", "source": [ "
\n", "\n", "Show a possible solution\n", "\n", "\n", "```python\n", "(neo4j_resources_df.write\n", " .format('org.neo4j.spark.DataSource')\n", " .mode('overwrite')\n", " .option('query', '''\n", " MERGE (a:Author {name: event.author})\n", " MERGE (r:Resource {name: event.resource})\n", " MERGE (a)-[:CREATED]->(r)\n", " WITH a, r, event\n", " UNWIND event.tags AS tag\n", " MERGE (t:Tag{name: tag})\n", " MERGE (r)-[:HAS_TAG]->(t)\n", " ''')\n", " .save())\n", "```\n", "\n", "
" ], "metadata": { "id": "KCfw_saanywo" } }, { "cell_type": "code", "source": [ "\"\"\"\n", " This paragraph is for validating the code the you\n", " wrote above, please execute it after you\n", " persisted neo4j_resources_df as Cypher query\n", "\"\"\"\n", "\n", "with neo4j_driver.session() as session:\n", " # count relationships\n", " def check_graph_consistency(tx):\n", " result = tx.run('''\n", " MATCH (a:Author)-[:CREATED]->(r:Resource)-[:HAS_TAG]->(t:Tag)\n", " RETURN a.name AS author, r.name AS resource, collect(t.name) AS tags\n", " ORDER By author\n", " ''')\n", " return [{'author': record['author'], 'resource': record['resource'], 'tags': set(record['tags'])} for record in result]\n", "\n", " actual = session.read_transaction(check_graph_consistency)\n", " expected = [\n", " {'author': 'LARUS Business Automation', 'resource': 'Galileo.XAI', 'tags': frozenset(['Graph Machine Learning', 'Neo4j', 'Explainable AI', 'Artificial Intelligence'])},\n", " {'author': 'Michael Hunger', 'resource': 'APOC', 'tags': frozenset(['Graph Data Integration', 'Graph Algorithms'])},\n", " {'author': 'Neo4j', 'resource': 'Graph Data Science Library', 'tags': frozenset(['Graph Machine Learning', 'Algorithms'])}\n", " ]\n", " assert actual == expected\n", "\n", "print(\"All assertion are successfuly satisfied. 
Congrats you saved your first Knowledge Graph DataFrame into Neo4j!\")" ], "metadata": { "id": "LwqbSsEcsgDi" }, "execution_count": null, "outputs": [] }, { "cell_type": "code", "source": [], "metadata": { "id": "_LbqufNZMj6-" }, "execution_count": null, "outputs": [] } ], "metadata": { "colab": { "provenance": [], "toc_visible": true }, "kernelspec": { "display_name": "Python 3", "name": "python3" }, "language_info": { "name": "python" } }, "nbformat": 4, "nbformat_minor": 0 } ================================================ FILE: examples/neo4j_data_science.ipynb ================================================ { "nbformat": 4, "nbformat_minor": 0, "metadata": { "colab": { "provenance": [], "toc_visible": true }, "kernelspec": { "name": "python3", "display_name": "Python 3" }, "language_info": { "name": "python" } }, "cells": [ { "cell_type": "markdown", "source": [ "Open this notebook in Google Colab \n", " \"Open\n", "" ], "metadata": { "id": "ciNaixnkx1vj" } }, { "cell_type": "markdown", "source": [ "# Example of a Simple data science workflow with Neo4j and Spark" ], "metadata": { "id": "zADiJjnuVfq2" } }, { "cell_type": "markdown", "source": [ "This notebook contains a set of examples that explains how the Neo4j Spark connector can fit in you Data Scinece workflow, how you can combine Spark Neo4j and the Graph Data Science library to extract insights from your data and mostly important it allows you to test your knowledge with a set of exercises after each section.\n", "\n", "If you have any questions or problems feel free to write a post in the [Neo4j community forum](https://community.neo4j.com/) or in [Discord](https://discord.com/invite/neo4j).\n", "\n", "If you want more exercises feel free to open an issue in the [GitHub repository](https://github.com/neo4j/neo4j-spark-connector).\n", "\n", "Enjoy!" 
], "metadata": { "id": "nLucMn17V0YK" } }, { "cell_type": "markdown", "source": [ "# Notes about this notebook\n", "\n", "This code contains a simple data science workflow that combines Neo4j's Graph Data Science Library with the Neo4j Connector for Apache Spark.\n", "\n", "Going forward you'll find code examples in:\n", "\n", "* PySpark\n", "* PySpark Pandas\n", "\n", "You can choose to navigate by using one of them, or both, but we suggest you do one at time to ensure you understand the APIs." ], "metadata": { "id": "pWWY8190RB98" } }, { "cell_type": "markdown", "source": [ "# Create the sandbox instance\n", "\n", "You can easily spin-up a Neo4j sandbox by click [here](https://sandbox.neo4j.com/?usecase=fraud-detection)\n", "\n", "After that you'll be redirect in a webpage like this:\n", "\n", "\n", "\n", "Please click in the **Connection details tab** and copy your connection parameters into the Python variables below" ], "metadata": { "id": "3hCQBmBKVaHm" } }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "ttxf62TPVP-w" }, "outputs": [], "source": [ "neo4j_url = \"\" # put your neo4j url here" ] }, { "cell_type": "code", "source": [ "neo4j_user = \"neo4j\" # put your neo4j user here" ], "metadata": { "id": "-lPr1hfIGtfL" }, "execution_count": null, "outputs": [] }, { "cell_type": "code", "source": [ "neo4j_password = \"\" # put your neo4j password here" ], "metadata": { "id": "yoI29jjvGvlX" }, "execution_count": null, "outputs": [] }, { "cell_type": "markdown", "source": [ "### Configure the Spark Environment" ], "metadata": { "id": "Capd99x5G2rm" } }, { "cell_type": "code", "source": [ "spark_version = '3.3.4'" ], "metadata": { "id": "OiHMiko1-Qf7" }, "execution_count": null, "outputs": [] }, { "cell_type": "code", "source": [ "!apt-get install openjdk-17-jdk-headless -qq > /dev/null" ], "metadata": { "id": "qdjzLBDzGx5l" }, "execution_count": null, "outputs": [] }, { "cell_type": "code", "source": [ "!wget -q 
https://dlcdn.apache.org/spark/spark-$spark_version/spark-$spark_version-bin-hadoop3.tgz" ], "metadata": { "id": "7JT9OKhzG7Lq" }, "execution_count": null, "outputs": [] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "A3gsnSHl0F99" }, "outputs": [], "source": [ "!tar xf spark-$spark_version-bin-hadoop3.tgz" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "hSBQWKs90vSx" }, "outputs": [], "source": [ "!pip install -q findspark" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "tnW0a1Gj080k" }, "outputs": [], "source": [ "import os\n", "os.environ[\"JAVA_HOME\"] = \"/usr/lib/jvm/java-17-openjdk-amd64\"\n", "os.environ[\"SPARK_HOME\"] = f\"/content/spark-{spark_version}-bin-hadoop3\"" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "dlUBSezK1DpZ" }, "outputs": [], "source": [ "import findspark\n", "findspark.init()" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "dOUJ-W871Tur" }, "outputs": [], "source": [ "from pyspark.sql import SparkSession\n", "spark = (SparkSession.builder\n", " .master('local[*]')\n", " .appName('Data science workflow with Neo4j and Spark')\n", " .config('spark.ui.port', '4050')\n", " # Just to show dataframes as tables\n", " #.config('spark.sql.repl.eagerEval.enabled', False)\n", " .config('spark.jars.packages', 'org.neo4j:neo4j-connector-apache-spark_2.12:5.1.0_for_spark_3')\n", " # As we're using always the same database instance we'll\n", " # define them as global variables\n", " # so we don't need to repeat them each time\n", " .config(\"neo4j.url\", neo4j_url)\n", " .config(\"neo4j.authentication.type\", \"basic\")\n", " .config(\"neo4j.authentication.basic.username\", neo4j_user)\n", " .config(\"neo4j.authentication.basic.password\", neo4j_password)\n", " .getOrCreate())\n", "spark" ] }, { "cell_type": "code", "source": [ "# import utility functions that we'll use in the notebook\n", "from pyspark.sql.types import 
*\n", "from pyspark.sql.functions import *" ], "metadata": { "id": "pghCcGnJWcZQ" }, "execution_count": null, "outputs": [] }, { "cell_type": "markdown", "source": [ "## Import PySpark Pandas\n", "\n", "Pandas API on Apache Spark (PySpark) enables data scientists and data engineers to run their existing Pandas code on Spark. Prior to this API, you had to do a significant code rewrite from Pandas DataFrame to PySpark DataFrame which is time-consuming and error-prone.\n", "\n", "In this notebook we'll use both PySpark Dataframes and and PySpark Pandas.\n", "\n", "The only thing that we need to do is to import the library using the statement below." ], "metadata": { "id": "klQ2Ah6CFBV1" } }, { "cell_type": "code", "source": [ "import pyspark.pandas as ps" ], "metadata": { "id": "lDkBcHySCBT0" }, "execution_count": null, "outputs": [] }, { "cell_type": "markdown", "source": [ "\n", "## Exercises prerequisite\n", "\n", "In this notebook we and going to test your knowledge. Some of the exercises require the Neo4j Python driver to check if the exercises are being solved correctly.\n", "\n", "*Neo4j Python Driver is required only for verifying the exercises when you persist data from Spark to Neo4j*\n", "\n", "**It's not required by the Spark connector!!!**\n", "\n", "We'll use [Cy2Py](https://github.com/conker84/cy2py), a Jupyter extension that easily allows you to connect to Neo4j and visualize data from Jupyter notebooks.\n", "For a detailed instruction about how to use it please dive into [this example](https://github.com/conker84/cy2py/blob/main/examples/Neo4j_Crime_Investigation_Dataset.ipynb)" ], "metadata": { "id": "b6_YNZnZ5GdT" } }, { "cell_type": "code", "source": [ "!pip install -q cy2py" ], "metadata": { "id": "f5ZZJylo5Bbz" }, "execution_count": null, "outputs": [] }, { "cell_type": "markdown", "source": [ "let's load the extension" ], "metadata": { "id": "uKYEPEgOcG2b" } }, { "cell_type": "code", "source": [ "%load_ext cy2py" ], "metadata": { "id": 
"38EeXF6icKOK" }, "execution_count": null, "outputs": [] }, { "cell_type": "markdown", "source": [ "You can query the database via **cy2py** in this simple way" ], "metadata": { "id": "peqcEHj0b35T" } }, { "cell_type": "code", "source": [ "# define the colors for the nodes\n", "colors = {\n", " ':Client': '#D18711',\n", " ':Bank': '#0541B2',\n", " ':Merchant': '#9E14AA',\n", " ':Mule': '#6113A3',\n", " ':CashIn': '#328918',\n", " ':CashOut': '#C1A23D',\n", " ':Debit': '#A32727',\n", " ':Payment': '#3B80C4',\n", " ':Transfer': '#088472',\n", " ':Transaction': '#D10B4F',\n", " ':Email': '#EA5D1E',\n", " ':SSN': '#707070',\n", " ':Phone': '#4B4444',\n", "}" ], "metadata": { "id": "dw2P-XpfLCJY" }, "execution_count": null, "outputs": [] }, { "cell_type": "code", "source": [ "%%cypher -u $neo4j_url -us $neo4j_user -pw $neo4j_password -co $colors\n", "CALL apoc.meta.graph()" ], "metadata": { "id": "BfFOTNkncMqp" }, "execution_count": null, "outputs": [] }, { "cell_type": "markdown", "source": [ "# Problem Definition\n" ], "metadata": { "id": "d-x29ClTPBnv" } }, { "cell_type": "markdown", "source": [ "## What is Fraud?\n", "Fraud occurs when an individual or group of individuals, or a business entity intentionally deceives another individual or business entity with misrepresentation of identity, products, services, or financial transactions and/or false promises with no intention of fulfilling them." ], "metadata": { "id": "79q5QJfcPMa6" } }, { "cell_type": "markdown", "source": [ "## Fraud Categories\n" ], "metadata": { "id": "naUmXhC-PQGR" } }, { "cell_type": "markdown", "source": [ "### First-party Fraud\n", "An individual, or group of individuals, misrepresent their identity or give false information when applying for a product or services to receive more favourable rates or when have no intention of repayment." 
], "metadata": { "id": "edTfWFSAPUKF" } }, { "cell_type": "markdown", "source": [ "### Second-party Fraud\n", "An individual knowingly gives their identity or personal information to another individual to commit fraud or someone is perpetrating fraud in his behalf." ], "metadata": { "id": "Zr9sGs_9PYmH" } }, { "cell_type": "markdown", "source": [ "### Third-party Fraud\n", "An individual, or a group of individuals, create or use another person’s identity, or personal details, to open or takeover an account." ], "metadata": { "id": "o45K16ryPcBu" } }, { "cell_type": "markdown", "source": [ "## The dataset\n", "\n", "We will use Paysim dataset for the hands-on exercises. Paysim is a synthetic dataset that mimics real world mobile money transfer network.\n", "\n", "For more information on the dataset, please visit this [blog page](https://www.sisu.io/posts/paysim/)" ], "metadata": { "id": "dFje1N1cPq_9" } }, { "cell_type": "code", "source": [ "%%cypher\n", "CALL apoc.meta.graph()" ], "metadata": { "id": "AAeicV33PDXa" }, "execution_count": null, "outputs": [] }, { "cell_type": "markdown", "source": [ "There are five types of transactions in this database. List all transaction types and corresponding metrics by iterating over all the transactions." 
], "metadata": { "id": "Ux5tg_OzUvgT" } }, { "cell_type": "markdown", "source": [ "#### Code in PySpark" ], "metadata": { "id": "viWCvG1MU632" } }, { "cell_type": "code", "source": [ "transaction_df = (spark.read\n", " .format('org.neo4j.spark.DataSource')\n", " .option('labels', ':Transaction')\n", " .load())\n", "\n", "transaction_df_count = transaction_df.count()\n", "\n", "transaction_df = (transaction_df.groupBy('')\n", " .count()\n", " .withColumnRenamed('', 'transaction'))\n", "\n", "transaction_df = (transaction_df\n", " .withColumn('transaction', transaction_df['transaction'].getItem(0))\n", " .withColumn('% transactions', transaction_df['count'] / transaction_df_count))\n", "\n", "transaction_df.show(truncate=False)" ], "metadata": { "id": "xsIlmR-EQLeb" }, "execution_count": null, "outputs": [] }, { "cell_type": "markdown", "source": [ "#### Code in PySpark Pandas" ], "metadata": { "id": "U2JUpe4NU6Jz" } }, { "cell_type": "code", "source": [ "transaction_ps = ps.read_spark_io(format=\"org.neo4j.spark.DataSource\", options={\"labels\": \"Transaction\"})\n", "\n", "transaction_ps_count = transaction_ps.count()[0] * 1.0\n", "\n", "transaction_ps = (transaction_ps.groupby([''])\n", " .size()\n", " .reset_index(name='% transactions'))\n", "\n", "transaction_ps = transaction_ps.rename(columns={'': 'label'})\n", "\n", "transaction_ps['% transactions'] = transaction_ps['% transactions'].astype(float).div(transaction_ps_count * 1.0)\n", "\n", "transaction_ps.label = [x[0] for x in transaction_ps.label.to_numpy()]\n", "\n", "transaction_ps" ], "metadata": { "id": "rfkeCU0PVFXx" }, "execution_count": null, "outputs": [] }, { "cell_type": "markdown", "source": [ "##### Plot the data\n", "You can also use Python libraries like [Ploty](https://plotly.com/python/) to plot results" ], "metadata": { "id": "t2RufDichKkQ" } }, { "cell_type": "code", "source": [ "import plotly.express as px\n", "\n", "# we use to_pandas() in order to transform the PySpark Pandas to a real 
Pandas Dataframe\n", "fig = px.pie(transaction_ps.to_pandas(), values='% transactions', names='label')\n", "\n", "fig.show()" ], "metadata": { "id": "Wn4HMYGVhV1t" }, "execution_count": null, "outputs": [] }, { "cell_type": "markdown", "source": [ "# Exploit first-party Fraud\n", "\n", "Synthetic identity fraud and first party fraud can be identified by performing entity link analysis to detect identities linked to other identities via shared PII.\n", "\n", "There are three types of personally identifiable information (PII) in this dataset - SSN, Email and Phone Number\n", "\n", "Our hypothesis is that clients who share identifiers are suspicious and have a higher potential to commit fraud. However, all shared identifier links are not suspicious, for example, two people sharing an email address. Hence, we compute a fraud score based on shared PII relationships and label the top X percentile clients as fraudsters.\n", "\n", "We will first identify clients that share identifiers and create a new relationship between clients that share identifiers" ], "metadata": { "id": "VwPKEtu2QLlv" } }, { "cell_type": "markdown", "source": [ "## Enrich the dataset" ], "metadata": { "id": "E162NudWkW2n" } }, { "cell_type": "markdown", "source": [ "In order to perform our investigation we want to enrich the base dataset by identifying clients that share PII."
], "metadata": { "id": "Y3MfFaKqH7Lm" } }, { "cell_type": "code", "source": [ "%%cypher\n", "MATCH (c1:Client)-[:HAS_EMAIL|:HAS_PHONE|:HAS_SSN]->(n)<-[:HAS_EMAIL|:HAS_PHONE|:HAS_SSN]-(c2:Client)\n", "WHERE id(c1) < id(c2)\n", "RETURN c1.id, c2.id, count(*) AS freq\n", "ORDER BY freq DESC;" ], "metadata": { "id": "zJfRBlNNP9A1" }, "execution_count": null, "outputs": [] }, { "cell_type": "markdown", "source": [ "Now we can reuse the same Cypher query for creating our Dataframe and then use the Neo4j Spark Connector to create a new `SHARED_IDENTIFIERS` relationship betwen two clients:\n", "\n", "**(:Client)-[:SHARED_IDENTIFIERS]->(:Client)**\n", "\n" ], "metadata": { "id": "4J6d8U8bkMW_" } }, { "cell_type": "code", "source": [ "%%cypher\n", "// let's check if there relationships are in there\n", "MATCH (c:Client)-[r:SHARED_IDENTIFIERS]->(c2:Client)\n", "RETURN *\n", "LIMIT 10" ], "metadata": { "id": "sQ7Nf_IUQQ1J" }, "execution_count": null, "outputs": [] }, { "cell_type": "markdown", "source": [ "As you can see there are no relationships in the database" ], "metadata": { "id": "6zUwaYxgQhbC" } }, { "cell_type": "markdown", "source": [ "### Code in PySpark" ], "metadata": { "id": "7CAiMFJ3LPmP" } }, { "cell_type": "code", "source": [ "shared_identifiers_df = (spark.read.format(\"org.neo4j.spark.DataSource\")\n", " .option(\"query\", \"\"\"\n", " MATCH (c1:Client)-[:HAS_EMAIL|:HAS_PHONE|:HAS_SSN]->(n)<-[:HAS_EMAIL|:HAS_PHONE|:HAS_SSN]-(c2:Client)\n", " WHERE id(c1) < id(c2)\n", " RETURN c1.id AS source, c2.id AS target, count(*) AS freq\n", " \"\"\")\n", " .load())\n", "\n", "(shared_identifiers_df.write\n", " .format(\"org.neo4j.spark.DataSource\")\n", " .mode(\"Overwrite\")\n", " .option(\"relationship\", \"SHARED_IDENTIFIERS\")\n", " .option(\"relationship.save.strategy\", \"keys\")\n", " .option(\"relationship.source.labels\", \":Client\")\n", " .option(\"relationship.source.save.mode\", \"Overwrite\")\n", " .option(\"relationship.source.node.keys\", 
\"source:id\")\n", " .option(\"relationship.target.labels\", \":Client\")\n", " .option(\"relationship.target.node.keys\", \"target:id\")\n", " .option(\"relationship.target.save.mode\", \"Overwrite\")\n", " .option(\"relationship.properties\", \"freq:count\")\n", " .save())" ], "metadata": { "id": "36irb2nuj5Hi" }, "execution_count": null, "outputs": [] }, { "cell_type": "markdown", "source": [ "### Code in PySpark Pandas" ], "metadata": { "id": "h07CRHnlLU-G" } }, { "cell_type": "code", "source": [ "shared_identifiers_ps = ps.read_spark_io(format=\"org.neo4j.spark.DataSource\", options={\"query\": \"\"\"\n", " MATCH (c1:Client)-[:HAS_EMAIL|:HAS_PHONE|:HAS_SSN]->(n)<-[:HAS_EMAIL|:HAS_PHONE|:HAS_SSN]-(c2:Client)\n", " WHERE id(c1) < id(c2)\n", " RETURN c1.id AS source, c2.id AS target, count(*) AS freq\n", "\"\"\"})\n", "\n", "shared_identifiers_ps.spark.to_spark_io(format=\"org.neo4j.spark.DataSource\", mode=\"Overwrite\", options={\n", " \"relationship\": \"SHARED_IDENTIFIERS\",\n", " \"relationship.save.strategy\": \"keys\",\n", " \"relationship.source.labels\": \":Client\",\n", " \"relationship.source.save.mode\": \"Overwrite\",\n", " \"relationship.source.node.keys\": \"source:id\",\n", " \"relationship.target.labels\": \":Client\",\n", " \"relationship.target.node.keys\": \"target:id\",\n", " \"relationship.target.save.mode\": \"Overwrite\",\n", " \"relationship.properties\": \"freq:count\"\n", "})" ], "metadata": { "id": "vZ6So9x_MBY_" }, "execution_count": null, "outputs": [] }, { "cell_type": "code", "source": [ "%%cypher\n", "// let's check (again) if there relationships are in there\n", "MATCH (c:Client)-[r:SHARED_IDENTIFIERS]->(c2:Client)\n", "RETURN *\n", "LIMIT 10" ], "metadata": { "id": "A6iUPMYMQAzF" }, "execution_count": null, "outputs": [] }, { "cell_type": "markdown", "source": [ "## Build Fraud detection workflow in Neo4j GDS\n", "\n", "We will construct a workflow with graph algorithms to detect fraud rings, score clients based on the number of 
common connections and rank them to select the top few suspicious clients and label them as fraudsters.\n", "\n", "1. Identify clusters of clients sharing PII using a community detection algorithm (Weakly Connected Components)\n", "2. Find similar clients within the clusters using pairwise similarity algorithms (Node Similarity)\n", "3. Calculate and assign fraud score to clients using centrality algorithms (Degree Centrality)\n", "4. Use computed fraud scores to label clients as potential fraudsters" ], "metadata": { "id": "Fy8roSDHQw3h" } }, { "cell_type": "markdown", "source": [ "## Identify groups of clients sharing PII (Fraud rings)\n", "\n", "Run Weakly connected components to find clusters of clients sharing PII.\n", "\n", "Weakly Connected Components is used to find groups of connected nodes, where all nodes in the same set form a connected component. WCC is often used early in an analysis to understand the structure of a graph. More information here: [WCC documentation](https://neo4j.com/docs/graph-data-science/current/algorithms/wcc/)" ], "metadata": { "id": "bysMYQ23WFVl" } }, { "cell_type": "markdown", "source": [ "### Create a graph projection\n", "\n", "A central concept in the GDS library is the management of in-memory graphs. Graph algorithms run on a graph data model which is a projection of the Neo4j property graph data model. For more information, please click here: [Graph Management](https://neo4j.com/docs/graph-data-science/current/management-ops/)\n", "\n", "A projected graph can be stored in the catalog under a user-defined name. Using that name, the graph can be referred to by any algorithm in the library."
], "metadata": { "id": "39g6Fq1dTgLt" } }, { "cell_type": "markdown", "source": [ "Consider that the original Cypher query is the following:\n", "```cypher\n", "CALL gds.graph.project('wcc',\n", " {\n", " Client: {\n", " label: 'Client'\n", " }\n", " },\n", " {\n", " SHARED_IDENTIFIERS:{\n", " type: 'SHARED_IDENTIFIERS',\n", " orientation: 'UNDIRECTED',\n", " properties: {\n", " count: {\n", " property: 'count'\n", " }\n", " }\n", " }\n", " }\n", ") YIELD graphName,nodeCount,relationshipCount,projectMillis;\n", "```\n", "\n", "which will be translate into:" ], "metadata": { "id": "fXQwdpJfVGIq" } }, { "cell_type": "markdown", "source": [ "#### Code in PySpark" ], "metadata": { "id": "-fkvDjHcUV5Z" } }, { "cell_type": "code", "source": [ "wcc_graph_proj_df = (spark.read.format(\"org.neo4j.spark.DataSource\")\n", " .option(\"gds\", \"gds.graph.project\")\n", " .option(\"gds.graphName\", \"wcc\")\n", " .option(\"gds.nodeProjection.Client.label\", \"Client\")\n", " .option(\"gds.relationshipProjection.SHARED_IDENTIFIERS.type\", \"SHARED_IDENTIFIERS\")\n", " .option(\"gds.relationshipProjection.SHARED_IDENTIFIERS.orientation\", \"UNDIRECTED\")\n", " .option(\"gds.relationshipProjection.SHARED_IDENTIFIERS.properties.count.property\", \"count\")\n", " .load())\n", "\n", "wcc_graph_proj_df.show(truncate=False)" ], "metadata": { "id": "4dBHWHh8R7US" }, "execution_count": null, "outputs": [] }, { "cell_type": "markdown", "source": [ "#### Code in PySpark Pandas" ], "metadata": { "id": "cauhX4FRVYMy" } }, { "cell_type": "code", "source": [ "wcc_graph_proj_ps = ps.read_spark_io(format=\"org.neo4j.spark.DataSource\", options={\n", " \"gds\": \"gds.graph.project\",\n", " \"gds.graphName\": \"wcc\",\n", " \"gds.nodeProjection.Client.label\": \"Client\",\n", " \"gds.relationshipProjection.SHARED_IDENTIFIERS.type\": \"SHARED_IDENTIFIERS\",\n", " \"gds.relationshipProjection.SHARED_IDENTIFIERS.orientation\": \"UNDIRECTED\",\n", " 
\"gds.relationshipProjection.SHARED_IDENTIFIERS.properties.count.property\": \"count\"\n", "})\n", "\n", "wcc_graph_proj_ps" ], "metadata": { "id": "ZndBGbXPVeqU" }, "execution_count": null, "outputs": [] }, { "cell_type": "markdown", "source": [ "### Run the WCC algorithm\n", "\n", "The original Cypher query is:\n", "\n", "```cypher\n", "CALL gds.wcc.stream('wcc',\n", " {\n", " nodeLabels: ['Client'],\n", " relationshipTypes: ['SHARED_IDENTIFIERS'],\n", " consecutiveIds: true\n", " }\n", ")\n", "YIELD nodeId, componentId\n", "RETURN gds.util.asNode(nodeId).id AS clientId, componentId\n", "ORDER BY componentId\n", "LIMIT 20\n", "```\n", "\n", "which is transate into:" ], "metadata": { "id": "P4oIKsUNn-ZH" } }, { "cell_type": "markdown", "source": [ "#### Code in PySpark" ], "metadata": { "id": "Ygw7T3lSWbsQ" } }, { "cell_type": "code", "source": [ "# get the clients\n", "clients_df = (spark.read.format(\"org.neo4j.spark.DataSource\")\n", " .option(\"labels\", \"Client\")\n", " .load())\n", "\n", "# invoke the gds wcc stream procedure\n", "wcc_df = (spark.read.format(\"org.neo4j.spark.DataSource\")\n", " .option(\"gds\", \"gds.wcc.stream\")\n", " .option(\"gds.graphName\", \"wcc\")\n", " .option(\"gds.nodeLabels\", \"['Client']\")\n", " .option(\"gds.relationshipTypes\", \"['SHARED_IDENTIFIERS']\")\n", " .option(\"gds.consecutiveIds\", \"true\")\n", " .load())\n", "\n", "# join the two dataframes and show id, componentId\n", "client_component_df = (clients_df.join(wcc_df, clients_df[\"\"] == wcc_df[\"nodeId\"], \"inner\")\n", " .select(\"id\", \"componentId\"))\n", "\n", "client_component_df.show(truncate=False)" ], "metadata": { "id": "6RtkJV9GWHnu" }, "execution_count": null, "outputs": [] }, { "cell_type": "markdown", "source": [ "#### Code in PySpark Pandas" ], "metadata": { "id": "xQI9gQNQYMWe" } }, { "cell_type": "code", "source": [ "# get the clients\n", "clients_ps = ps.read_spark_io(format=\"org.neo4j.spark.DataSource\", options={\"labels\": 
\"Client\"})\n", "\n", "# invoke the gds wcc stream procedure\n", "wcc_ps = ps.read_spark_io(format=\"org.neo4j.spark.DataSource\", options={\n", " \"gds\": \"gds.wcc.stream\",\n", " \"gds.graphName\": \"wcc\",\n", " \"gds.nodeLabels\": \"['Client']\",\n", " \"gds.relationshipTypes\": \"['SHARED_IDENTIFIERS']\",\n", " \"gds.consecutiveIds\": \"true\"\n", "})\n", "\n", "# join the two pandas df and show id, componentId\n", "client_component_ps = clients_ps.join(wcc_ps.set_index(\"nodeId\"), on=\"\")[[\"id\", \"componentId\"]]\n", "\n", "# we show only the first 20\n", "client_component_ps[:20]" ], "metadata": { "id": "AvlUnpIQYQsZ" }, "execution_count": null, "outputs": [] }, { "cell_type": "markdown", "source": [ "### Write results to the database.\n", "Now that we identified clusters of clients sharing PII, we want to store these results back into the database by enriching the `Client` node.\n", "We'll add the component id of the cluster as `firstPartyFraudGroup` property" ], "metadata": { "id": "eNgmAuheZqfA" } }, { "cell_type": "markdown", "source": [ "#### Code in PySpark" ], "metadata": { "id": "EKNcuklDaRKY" } }, { "cell_type": "code", "source": [ "(client_component_df\n", " .withColumnRenamed(\"componentId\", \"firstPartyFraudGroup\")\n", " .write\n", " .format(\"org.neo4j.spark.DataSource\")\n", " .mode(\"Overwrite\")\n", " .option(\"labels\", \"Client\")\n", " .option(\"node.keys\", \"id\")\n", " .save())" ], "metadata": { "id": "yQb0H-p7ZrRP" }, "execution_count": null, "outputs": [] }, { "cell_type": "markdown", "source": [ "#### Code in PySpark Pandas" ], "metadata": { "id": "-HK28CGoa5_p" } }, { "cell_type": "code", "source": [ "(client_component_ps\n", " .rename(columns={\"componentId\": \"firstPartyFraudGroup\"})\n", " .spark\n", " .to_spark_io(format=\"org.neo4j.spark.DataSource\", mode=\"Overwrite\", options={\n", " \"labels\": \"Client\",\n", " \"node.keys\": \"id\"\n", " }))" ], "metadata": { "id": "zJDS-0bta8_s" }, "execution_count": null, 
"outputs": [] }, { "cell_type": "code", "source": [ "%%cypher\n", "// Visualize clusters with greater than 9 client nodes.\n", "MATCH (c:Client)\n", "WITH c.firstPartyFraudGroup AS fpGroupID, collect(c.id) AS fGroup\n", "WITH *, size(fGroup) AS groupSize WHERE groupSize >= 9\n", "WITH * LIMIT 1\n", "MATCH p=(c:Client)-[:HAS_SSN|HAS_EMAIL|HAS_PHONE]->()\n", "WHERE c.firstPartyFraudGroup = fpGroupID\n", "RETURN p" ], "metadata": { "id": "oOsrNUZocx21" }, "execution_count": null, "outputs": [] }, { "cell_type": "markdown", "source": [ "## Compute pairwise similarity scores\n", "\n", "We use node similarity algorithm to find similar nodes based on the relationships to other nodes. Node similarity uses Jaccard metric ([Node Similarity](https://neo4j.com/docs/graph-data-science/current/algorithms/node-similarity/#algorithms-node-similarity))\n", "\n", "Node similarity algorithms work on bipartite graphs (two types of nodes and relationships between them). Here we project client nodes (one type) and three identifiers nodes (that are considered as second type) into memory." 
], "metadata": { "id": "5CCwYp1FfoMU" } }, { "cell_type": "markdown", "source": [ "### Project the graph\n", "\n", "The original Cypher query is\n", "\n", "```cypher\n", "MATCH(c:Client) WHERE c.firstPartyFraudGroup is not NULL\n", "WITH collect(c) as clients\n", "MATCH(n) WHERE n:Email OR n:Phone OR n:SSN\n", "WITH clients, collect(n) as identifiers\n", "WITH clients + identifiers as nodes\n", "\n", "MATCH(c:Client) -[:HAS_EMAIL|:HAS_PHONE|:HAS_SSN]->(id)\n", "WHERE c.firstPartyFraudGroup is not NULL\n", "WITH nodes, collect({source: c, target: id}) as relationships\n", "\n", "CALL gds.graph.project.cypher('similarity',\n", " \"UNWIND $nodes as n RETURN id(n) AS id,labels(n) AS labels\",\n", " \"UNWIND $relationships as r RETURN id(r['source']) AS source, id(r['target']) AS target, 'HAS_IDENTIFIER' as type\",\n", " { parameters: {nodes: nodes, relationships: relationships}}\n", ")\n", "YIELD graphName, nodeCount, relationshipCount, projectMillis\n", "RETURN graphName, nodeCount, relationshipCount, projectMillis\n", "```\n", "\n", "Which is translated into" ], "metadata": { "id": "aLGQxFtpnQHa" } }, { "cell_type": "markdown", "source": [ "#### Code in PySpark" ], "metadata": { "id": "aPSY4htNgLVG" } }, { "cell_type": "code", "source": [ "similarity_graph_proj_df = (spark.read.format(\"org.neo4j.spark.DataSource\")\n", " .option(\"gds\", \"gds.graph.project.cypher\")\n", " .option(\"gds.graphName\", \"similarity\")\n", " .option(\"gds.nodeQuery\", \"\"\"\n", " MATCH (n)\n", " WHERE (n:Client AND n.firstPartyFraudGroup is not NULL) OR n:Email OR n:Phone OR n:SSN\n", " RETURN id(n) AS id, labels(n) AS labels\n", " \"\"\")\n", " .option(\"gds.relationshipQuery\", \"\"\"\n", " MATCH (s:Client)-[:HAS_EMAIL|:HAS_PHONE|:HAS_SSN]->(t)\n", " WHERE s.firstPartyFraudGroup is not NULL\n", " RETURN id(s) AS source, id(t) AS target, 'HAS_IDENTIFIER' as type\n", " \"\"\")\n", " .load())\n", "\n", "similarity_graph_proj_df.show(truncate=False)" ], "metadata": { "id": "eNvVkJTRfuqM" 
}, "execution_count": null, "outputs": [] }, { "cell_type": "markdown", "source": [ "#### Code in PySpark Pandas" ], "metadata": { "id": "PcJCKDDhmbtr" } }, { "cell_type": "code", "source": [ "similarity_graph_proj_ps = ps.read_spark_io(format=\"org.neo4j.spark.DataSource\", options={\n", " \"gds\": \"gds.graph.project.cypher\",\n", " \"gds.graphName\": \"similarity\",\n", " \"gds.nodeQuery\": \"\"\"\n", " MATCH (n)\n", " WHERE (n:Client AND n.firstPartyFraudGroup is not NULL) OR n:Email OR n:Phone OR n:SSN\n", " RETURN id(n) AS id, labels(n) AS labels\n", " \"\"\",\n", " \"gds.relationshipQuery\": \"\"\"\n", " MATCH (s:Client)-[:HAS_EMAIL|:HAS_PHONE|:HAS_SSN]->(t)\n", " WHERE s.firstPartyFraudGroup is not NULL\n", " RETURN id(s) AS source, id(t) AS target, 'HAS_IDENTIFIER' as type\n", " \"\"\"\n", "})\n", "\n", "similarity_graph_proj_ps" ], "metadata": { "id": "odriamHwmiEQ" }, "execution_count": null, "outputs": [] }, { "cell_type": "markdown", "source": [ "### Compute the node similarity\n", "\n", "We can mutate in-memory graph by writing outputs from the algorithm as node or relationship properties.\n", "\n", "In this particular case all the procedures with `mutate` and `write` suffix are not supported from the Neo4j Spark Connector, in this case we'll write a Cypher query:\n" ], "metadata": { "id": "TfnVREvvmzG4" } }, { "cell_type": "code", "source": [ "%%cypher\n", "CALL gds.nodeSimilarity.mutate('similarity',\n", " {\n", " topK:15,\n", " mutateProperty: 'jaccardScore',\n", " mutateRelationshipType:'SIMILAR_TO'\n", " }\n", ");" ], "metadata": { "id": "E0_WoCLlnIr6" }, "execution_count": null, "outputs": [] }, { "cell_type": "markdown", "source": [ "Mutate mode is very fast compared to write mode and it helps in optimizing algorithm execution times, then we write back the property from in-memory graph to the database and use it for further analysis:" ], "metadata": { "id": "JnC3C7urPyN_" } }, { "cell_type": "code", "source": [ "%%cypher\n", "CALL 
gds.graph.writeRelationship('similarity', 'SIMILAR_TO', 'jaccardScore');" ], "metadata": { "id": "xajV8enLPvIN" }, "execution_count": null, "outputs": [] }, { "cell_type": "markdown", "source": [ "## Exercise: Calculate First-party Fraud Score\n", "\n", "We compute first party fraud score using weighted degree centrality algorithm.\n", "\n", "In this step, we compute and assign fraud score (`firstPartyFraudScore`) to clients in the clusters identified in previous steps based on `SIMILAR_TO` relationships weighted by `jaccardScore`\n", "\n", "Weighted degree centrality algorithm add up similarity scores (`jaccardScore`) on the incoming `SIMILAR_TO` relationships for a given node in a cluster and assign the sum as the corresponding `firstPartyFraudScore`. This score represents clients who are similar to many others in the cluster in terms of sharing identifiers. Higher `firstPartyFraudScore` represents greater potential for committing fraud." ], "metadata": { "id": "qpXXoWeIQk9U" } }, { "cell_type": "markdown", "source": [ "### Code in PySpark" ], "metadata": { "id": "fAp_acV-RBOu" } }, { "cell_type": "code", "source": [ "# invoke the gds.degree.stream procedure" ], "metadata": { "id": "ZPAHXvT6Qouy" }, "execution_count": null, "outputs": [] }, { "cell_type": "markdown", "source": [ "
\n", "\n", "Show a possible solution\n", "\n", "\n", "```python\n", "similarity_df = (spark.read.format(\"org.neo4j.spark.DataSource\")\n", " .option(\"gds\", \"gds.degree.stream\")\n", " .option(\"gds.graphName\", \"similarity\")\n", " .option(\"gds.nodeLabels\", \"['Client']\")\n", " .option(\"gds.relationshipTypes\", \"['SIMILAR_TO']\")\n", " .option(\"gds.relationshipWeightProperty\", \"jaccardScore\")\n", " .load())\n", "\n", "# join the two dataframes and show id, score\n", "client_similarity_df = (clients_df.join(similarity_df, clients_df[\"\"] == similarity_df[\"nodeId\"], \"inner\")\n", " .select(\"id\", \"score\")\n", " .withColumnRenamed(\"score\", \"firstPartyFraudScore\"))\n", "\n", "# write the results back to the database\n", "(client_similarity_df.write.format('org.neo4j.spark.DataSource')\n", " .mode(\"Overwrite\")\n", " .option(\"labels\", \"Client\")\n", " .option(\"node.keys\", \"id\")\n", " .save())\n", "```\n", "\n", "
" ], "metadata": { "id": "1qKbMW3FSwLE" } }, { "cell_type": "markdown", "source": [ "### Code in PySpark Pandas\n", "\n" ], "metadata": { "id": "na2h6VLov-ZA" } }, { "cell_type": "code", "source": [ "# invoke the gds.degree.stream procedure" ], "metadata": { "id": "7ObUlkdmwDzi" }, "execution_count": null, "outputs": [] }, { "cell_type": "markdown", "source": [ "
\n", "\n", "Show a possible solution\n", "\n", "\n", "```python\n", "similarity_ps = ps.read_spark_io(format=\"org.neo4j.spark.DataSource\", options={\n", " \"gds\": \"gds.degree.stream\",\n", " \"gds.graphName\": \"similarity\",\n", " \"gds.nodeLabels\": \"['Client']\",\n", " \"gds.relationshipTypes\": \"['SIMILAR_TO']\",\n", " \"gds.relationshipWeightProperty\": \"jaccardScore\"\n", "})\n", "\n", "# join the two pandas df and show id, score\n", "client_similarity_ps = (clients_ps.join(similarity_ps.set_index(\"nodeId\"), on=\"\")[[\"id\", \"score\"]]\n", " .rename(columns={\"score\": \"firstPartyFraudScore\"}))\n", "\n", "# write the results back to the database\n", "client_similarity_ps.spark.to_spark_io(format=\"org.neo4j.spark.DataSource\", mode=\"Overwrite\", options={\n", " \"labels\": \"Client\",\n", " \"node.keys\": \"id\"\n", "})\n", "```\n", "\n", "
" ], "metadata": { "id": "m8Cy4NtsS3nQ" } }, { "cell_type": "markdown", "source": [ "### Verifiy the result\n", "\n", "We expect that:\n", "- `similarity_df`/`similarity_ps`\n", " - has two columns:\n", " - `nodeId` of long type\n", " - `score` of double type\n", " - a count of **9134** rows\n", "- `client_similarity_df`/`client_similarity_ps`\n", " - has two columns:\n", " - `id` of long type\n", " - `score` of double type\n", " - a count of 2433 rows" ], "metadata": { "id": "I_TKq9TmU3cS" } }, { "cell_type": "markdown", "source": [ "#### Test PySpark Dataframe" ], "metadata": { "id": "Dow9PWAtxBPu" } }, { "cell_type": "code", "source": [ "assert StructType([StructField(\"nodeId\", LongType()), StructField(\"score\", DoubleType())]) == similarity_df.schema\n", "assert 9134 == similarity_df.count()\n", "\n", "assert StructType([StructField(\"id\", StringType()), StructField(\"firstPartyFraudScore\", DoubleType())]) == client_similarity_df.schema\n", "assert 2433 == client_similarity_df.count()\n", "print(\"All assertion are successfuly satisfied.\")" ], "metadata": { "id": "5E6fUHwZU7PN" }, "execution_count": null, "outputs": [] }, { "cell_type": "markdown", "source": [ "#### Test PySpark Pandas" ], "metadata": { "id": "feuIfTCqxGKt" } }, { "cell_type": "code", "source": [ "assert StructType([StructField(\"nodeId\", LongType()), StructField(\"score\", DoubleType())]) == similarity_ps.to_spark().schema\n", "assert 9134 == similarity_ps.count()[0]\n", "\n", "assert StructType([StructField(\"id\", StringType()), StructField(\"firstPartyFraudScore\", DoubleType())]) == client_similarity_ps.to_spark().schema\n", "assert 2433 == client_similarity_ps.count()[0]\n", "print(\"All assertion are successfuly satisfied.\")" ], "metadata": { "id": "OHel2EYoxFpC" }, "execution_count": null, "outputs": [] }, { "cell_type": "markdown", "source": [ "We find clients with first-party fraud score greater than some threshold (X) and label those top X percentile clients as fraudsters. 
In this example, using 95th percentile as a threshold, we set a property FirstPartyFraudster on the Client node." ], "metadata": { "id": "YBhfX5DYXsom" } }, { "cell_type": "code", "source": [ "%%cypher\n", "MATCH (c:Client)\n", "WHERE c.firstPartyFraudScore IS NOT NULL\n", "WITH percentileCont(c.firstPartyFraudScore, 0.95) AS firstPartyFraudThreshold\n", "MATCH (c:Client)\n", "WHERE c.firstPartyFraudScore > firstPartyFraudThreshold\n", "SET c:FirstPartyFraudster" ], "metadata": { "id": "bSb4rZ-aXvYV" }, "execution_count": null, "outputs": [] }, { "cell_type": "markdown", "source": [ "# Second-party Fraud / Money Mules\n", "\n", "The first step is to find out clients who weren't identified as first party fraudsters but they transact with first party fraudsters." ], "metadata": { "id": "WyKyntZ_YC1g" } }, { "cell_type": "code", "source": [ "%%cypher\n", "MATCH p=(:Client:FirstPartyFraudster)-[]-(:Transaction)-[]-(c:Client)\n", "WHERE NOT c:FirstPartyFraudster\n", "RETURN p\n", "LIMIT 50" ], "metadata": { "id": "C94Yg_psYTRW" }, "execution_count": null, "outputs": [] }, { "cell_type": "markdown", "source": [ "Also, lets find out what types of transactions do these Clients perform with first party fraudsters" ], "metadata": { "id": "Tm9WYZnwsQif" } }, { "cell_type": "code", "source": [ "%%cypher\n", "MATCH (:Client:FirstPartyFraudster)-[]-(txn:Transaction)-[]-(c:Client)\n", "WHERE NOT c:FirstPartyFraudster\n", "UNWIND labels(txn) AS transactionType\n", "RETURN transactionType, count(*) AS freq" ], "metadata": { "id": "x2P9Y7IzY2Vl" }, "execution_count": null, "outputs": [] }, { "cell_type": "markdown", "source": [ "## Create new relationships\n", "\n", "Let’s go ahead and create `TRANSFER_TO` relationships between clients with `firstPartyFraudster` tags and other clients. 
Also add the total amount from all such transactions as a property on `TRANSFER_TO` relationships.\n", "\n", "Since the total amount transferred from a fraudster to a client and the total amount transferred in the reverse direction are not the same, we have to create relationships in two separate queries.\n", "\n", "* `TRANSFER_TO` relationship from a fraudster to a client (look at the directions in queries)\n", "* Add `SecondPartyFraudSuspect` tag to these clients" ], "metadata": { "id": "f2HTap1vuLBW" } }, { "cell_type": "code", "source": [ "%%cypher\n", "MATCH (c1:FirstPartyFraudster)-[]->(t:Transaction)-[]->(c2:Client)\n", "WHERE NOT c2:FirstPartyFraudster\n", "WITH c1, c2, sum(t.amount) AS totalAmount\n", "SET c2:SecondPartyFraudSuspect\n", "CREATE (c1)-[:TRANSFER_TO {amount:totalAmount}]->(c2)" ], "metadata": { "id": "lfGkbJk2ueau" }, "execution_count": null, "outputs": [] }, { "cell_type": "markdown", "source": [ "* `TRANSFER_TO` relationship from a client to a fraudster." ], "metadata": { "id": "6yO2wJP6uksA" } }, { "cell_type": "code", "source": [ "%%cypher\n", "MATCH (c1:FirstPartyFraudster)<-[]-(t:Transaction)<-[]-(c2:Client)\n", "WHERE NOT c2:FirstPartyFraudster\n", "WITH c1, c2, sum(t.amount) AS totalAmount\n", "SET c2:SecondPartyFraudSuspect\n", "CREATE (c1)<-[:TRANSFER_TO {amount:totalAmount}]-(c2);" ], "metadata": { "id": "WUeeP4MhujzM" }, "execution_count": null, "outputs": [] }, { "cell_type": "markdown", "source": [ "Visualize newly created `TRANSFER_TO` relationships" ], "metadata": { "id": "68rcS85lutUx" } }, { "cell_type": "code", "source": [ "%%cypher\n", "MATCH p=(:Client:FirstPartyFraudster)-[:TRANSFER_TO]-(c:Client)\n", "WHERE NOT c:FirstPartyFraudster\n", "RETURN p\n", "LIMIT 50" ], "metadata": { "id": "LC5kSXOGuu1w" }, "execution_count": null, "outputs": [] }, { "cell_type": "markdown", "source": [ "## Goal\n", "\n", "Our objective is to find out clients who may have supported the first party fraudsters and were not identified as 
potential first party fraudsters.\n", "\n", "Our hypothesis is that clients who perform transactions of type `Transfer` where they either send or receive money from first party fraudsters are flagged as suspects for second party fraud.\n", "\n", "To identify such clients, make use of `TRANSFER_TO` relationships and use this recipe:\n", "\n", "* Use **WCC** (community detection) to identify networks of clients who are connected to first party fraudsters\n", "* Use **PageRank** (centrality) to score clients based on their influence in terms of the amount of money transferred to/from fraudsters\n", "* Assign risk score (`secondPartyFraudScore`) to these clients" ], "metadata": { "id": "6swyguhivCEL" } }, { "cell_type": "markdown", "source": [ "## Exercise: Project the graph\n", "\n", "Let’s use native projection and create an in-memory graph with Client nodes and TRANSFER_TO relationships.\n", "\n", "We want to project:\n", "* `Client` for `nodeProjection`\n", "* `TRANSFER_TO` for `relationshipProjection`\n", "* for the configuration we want to set `relationshipProperties` to `amount`" ], "metadata": { "id": "ILntBuwxvXCz" } }, { "cell_type": "markdown", "source": [ "### Code in PySpark" ], "metadata": { "id": "zgKufo0Z2C8e" } }, { "cell_type": "code", "source": [ "second_party_graph_proj_df = # insert your PySpark code here" ], "metadata": { "id": "JmhR2nhwvWRu" }, "execution_count": null, "outputs": [] }, { "cell_type": "markdown", "source": [ "
\n", "\n", "Show a possible solution\n", "\n", "\n", "```python\n", "second_party_graph_proj_df = (spark.read.format(\"org.neo4j.spark.DataSource\")\n", " .option(\"gds\", \"gds.graph.project\")\n", " .option(\"gds.graphName\", \"SecondPartyFraudNetwork\")\n", " .option(\"gds.nodeProjection\", \"Client\")\n", " .option(\"gds.relationshipProjection\", \"TRANSFER_TO\")\n", " .option(\"gds.configuration.relationshipProperties\", \"amount\")\n", " .load())\n", "```\n", "\n", "
" ], "metadata": { "id": "Ab6GEhJ_Tsrx" } }, { "cell_type": "markdown", "source": [ "### Code in PySpark Pandas" ], "metadata": { "id": "QHnI8W5J2Gcp" } }, { "cell_type": "code", "source": [ "second_party_graph_proj_ps = # insert your PySpark Pandas code here" ], "metadata": { "id": "kD6vXaab3JdM" }, "execution_count": null, "outputs": [] }, { "cell_type": "markdown", "source": [ "
\n", "\n", "Show a possible solution\n", "\n", "\n", "```python\n", "second_party_graph_proj_ps = ps.read_spark_io(format=\"org.neo4j.spark.DataSource\", options={\n", " \"gds\": \"gds.graph.project\",\n", " \"gds.graphName\": \"SecondPartyFraudNetwork\",\n", " \"gds.nodeProjection\": \"Client\",\n", " \"gds.relationshipProjection\": \"TRANSFER_TO\",\n", " \"gds.configuration.relationshipProperties\": \"amount\"\n", "})\n", "```\n", "\n", "
" ], "metadata": { "id": "c0VjV2rBTy10" } }, { "cell_type": "markdown", "source": [ "### Verify the projection" ], "metadata": { "id": "ZVul68Wi162b" } }, { "cell_type": "markdown", "source": [ "#### Test PySpark Dataframe" ], "metadata": { "id": "jlOuUGA833ny" } }, { "cell_type": "code", "source": [ "second_party_graph_proj_df.cache()\n", "\n", "first_row = [\n", " {\n", " \"node\": list(row[\"nodeProjection\"].keys())[0],\n", " \"rel\": list(row[\"relationshipProjection\"].keys())[0],\n", " \"graphName\": row[\"graphName\"],\n", " \"nodeCount\": row[\"nodeCount\"],\n", " \"relCount\": row[\"relationshipCount\"]\n", " } for row in second_party_graph_proj_df.collect()\n", "][0]\n", "\n", "assert first_row == {\"node\": \"Client\", \"rel\": \"TRANSFER_TO\", \"graphName\": \"SecondPartyFraudNetwork\", \"nodeCount\": 2433, \"relCount\": 367}\n", "print(\"All assertion are successfuly satisfied.\")" ], "metadata": { "id": "hkjx6Nu339wW" }, "execution_count": null, "outputs": [] }, { "cell_type": "markdown", "source": [ "#### Test PySpark Pandas" ], "metadata": { "id": "37mDMX8G4EPE" } }, { "cell_type": "code", "source": [ "second_party_graph_proj_ps_df = second_party_graph_proj_ps.to_spark()\n", "\n", "second_party_graph_proj_ps_df.cache()\n", "\n", "first_row = [\n", " {\n", " \"node\": list(row[\"nodeProjection\"].keys())[0],\n", " \"rel\": list(row[\"relationshipProjection\"].keys())[0],\n", " \"graphName\": row[\"graphName\"],\n", " \"nodeCount\": row[\"nodeCount\"],\n", " \"relCount\": row[\"relationshipCount\"]\n", " } for row in second_party_graph_proj_ps_df.collect()\n", "][0]\n", "\n", "assert first_row == {\"node\": \"Client\", \"rel\": \"TRANSFER_TO\", \"graphName\": \"SecondPartyFraudNetwork\", \"nodeCount\": 2433, \"relCount\": 367}\n", "print(\"All assertion are successfuly satisfied.\")" ], "metadata": { "id": "spF535cp4IcR" }, "execution_count": null, "outputs": [] }, { "cell_type": "markdown", "source": [ "## Check clusters with more than one 
clients\n", "We will see if there are any clusters with more than one clients in them and if there are, then we should add a tag `secondPartyFraudGroup` to find them later using local queries." ], "metadata": { "id": "32FL1L324QxV" } }, { "cell_type": "markdown", "source": [ "### Code in PySpark" ], "metadata": { "id": "UljCEN-s40mc" } }, { "cell_type": "code", "source": [ "# invoke gds.wcc.stream on\n", "second_party_wcc_df = (spark.read.format(\"org.neo4j.spark.DataSource\")\n", " .option(\"gds\", \"gds.wcc.stream\")\n", " .option(\"gds.graphName\", \"SecondPartyFraudNetwork\")\n", " .load())\n", "\n", "# join the two dataframes aggregate by componentId\n", "# and filtering for clusters with a size greater then 1\n", "second_party_client_component_df = (clients_df.join(second_party_wcc_df, clients_df[\"\"] == second_party_wcc_df[\"nodeId\"], \"inner\")\n", " .groupBy(\"componentId\")\n", " .agg(count(\"*\").alias(\"count\"), collect_list(clients_df[\"id\"]).alias(\"cluster\"))\n", " .filter(size(\"cluster\") > 1))\n", "\n", "second_party_client_component_df = (second_party_client_component_df\n", " .withColumn(\"id\", explode(col(\"cluster\")))\n", " .select(\"id\", \"componentId\"))\n", "\n", "\n", "second_party_client_component_df.show(truncate=False)" ], "metadata": { "id": "jKK3m8rH150-" }, "execution_count": null, "outputs": [] }, { "cell_type": "markdown", "source": [ "### Code in PySpark Pandas" ], "metadata": { "id": "cL0e6QHDZZZi" } }, { "cell_type": "code", "source": [ "second_party_wcc_ps = ps.read_spark_io(format=\"org.neo4j.spark.DataSource\", options={\n", " \"gds\": \"gds.wcc.stream\",\n", " \"gds.graphName\": \"SecondPartyFraudNetwork\"\n", "})\n", "\n", "second_party_client_component_ps = (clients_ps.join(second_party_wcc_ps.set_index(\"nodeId\"), on=\"\")\n", " .groupby([\"componentId\"])\n", " .id\n", " .apply(list)\n", " .reset_index()\n", " .rename(columns={'id': 'cluster'})\n", ")\n", "\n", "second_party_client_component_ps = 
second_party_client_component_ps[second_party_client_component_ps[\"cluster\"].apply(lambda a: len(a)) > 1]\n", "\n", "second_party_client_component_ps = (second_party_client_component_ps\n", " .explode(\"cluster\")\n", " .rename(columns={'cluster': 'id'})\n", " [[\"id\", \"componentId\"]])\n", "\n", "second_party_client_component_ps[:20]" ], "metadata": { "id": "49OxrIuTZdb-" }, "execution_count": null, "outputs": [] }, { "cell_type": "markdown", "source": [ "### Exercise: Write the results back to database\n", "\n", "Write a Spark job that given the columns (`id`, `componentId`) retrieve the `Client` by the `id` and set for the node a property `secondPartyFraudGroup` with the value of `componentId`" ], "metadata": { "id": "9IsIbHD7b3cL" } }, { "cell_type": "markdown", "source": [ "#### Code in PySpark" ], "metadata": { "id": "1U2kzCyWdv8l" } }, { "cell_type": "code", "source": [ "# write your PySpark code here" ], "metadata": { "id": "PTid4oiqg_lu" }, "execution_count": null, "outputs": [] }, { "cell_type": "markdown", "source": [ "
\n", "\n", "Show a possible solution\n", "\n", "\n", "```python\n", "(second_party_client_component_df\n", " .withColumnRenamed(\"componentId\", \"secondPartyFraudGroup\")\n", " .write\n", " .format(\"org.neo4j.spark.DataSource\")\n", " .mode(\"Overwrite\")\n", " .option(\"labels\", \"Client\")\n", " .option(\"node.keys\", \"id\")\n", " .save())\n", "```\n", "\n", "
" ], "metadata": { "id": "BQB9o9JCYuOs" } }, { "cell_type": "markdown", "source": [ "#### Code in PySpark Pandas" ], "metadata": { "id": "QgUs-jzIdyWJ" } }, { "cell_type": "code", "source": [ "# write your PySpark Pandas code here" ], "metadata": { "id": "Copd3BPHduvJ" }, "execution_count": null, "outputs": [] }, { "cell_type": "markdown", "source": [ "
\n", "\n", "Show a possible solution\n", "\n", "\n", "```python\n", "(client_component_ps\n", " .rename(columns={\"componentId\": \"secondPartyFraudGroup\"})\n", " .spark\n", " .to_spark_io(format=\"org.neo4j.spark.DataSource\", mode=\"Overwrite\", options={\n", " \"labels\": \"Client\",\n", " \"node.keys\": \"id\"\n", " }))\n", "```\n", "\n", "
" ], "metadata": { "id": "7uiMcrNHY3GS" } }, { "cell_type": "markdown", "source": [ "#### Verify the Spark job result" ], "metadata": { "id": "0mi9x4ragzyX" } }, { "cell_type": "code", "source": [ "secondPartyFraudGroup_check_count = %cypher -u $neo4j_url MATCH (c:Client) WHERE c.secondPartyFraudGroup IS NOT NULL RETURN count(c) AS count\n", "assert 2433 == secondPartyFraudGroup_check_count['count'][0]\n", "print(\"All assertion are successfuly satisfied.\")" ], "metadata": { "id": "UM3TB_KxgzSh" }, "execution_count": null, "outputs": [] }, { "cell_type": "code", "source": [], "metadata": { "id": "xSkwu_ovYaI7" }, "execution_count": null, "outputs": [] } ] } ================================================ FILE: jreleaser.yml ================================================ # Generated with JReleaser 1.17.0 at 2025-04-28T13:27:24.943485+01:00 project: name: neo4j-spark-connector description: Neo4j Connector for Spark authors: - Connectors Team license: Apache-2.0 copyright: Neo4j, Inc. links: homepage: https://github.com/neo4j/neo4j-spark-connector languages: java: groupId: org.neo4j release: github: owner: neo4j name: neo4j-spark-connector tagName: "{{projectVersion}}" update: enabled: true sections: - TITLE - BODY - ASSETS artifacts: false files: true checksums: true changelog: formatted: ALWAYS preset: conventional-commits skipMergeCommits: true format: "- {{commitShortHash}} {{conventionalCommitDescription}}" links: true labelers: - title: ci label: ci excludeLabels: - ci hide: contributors: - '[bot]' - GitHub checksum: name: 'neo4j-connector-apache-spark-{{projectVersion}}_checksums.txt' algorithms: - SHA_256 files: true artifacts: true individual: true files: globs: - pattern: artifacts/*.jar - pattern: artifacts/*.zip extraProperties: skipRelease: "true" assemble: archive: zip: active: ALWAYS exported: true stereotype: NONE archiveName: neo4j-connector-apache-spark-{{projectVersion}} distributionType: BINARY attachPlatform: false formats: - ZIP fileSets: - 
input: '{{basedir}}/artifacts' output: . includes: - '{{projectName}}-{{projectVersion}}*.jar' templateDirectory: spark-3/src/jreleaser/assemblers/zip hooks: script: before: - filter: includes: [ "assemble" ] shell: BASH run: | rm -rf artifacts - filter: includes: [ "assemble" ] matrix: vars: scala: [ "2.12", "2.13" ] continueOnError: false verbose: true shell: BASH run: | mkdir artifacts || true ./maven-release.sh deploy {{matrix.scala}} default::file://{{basedir}}/target/{{matrix.scala}}/maven-artifacts cp -r {{basedir}}/target/{{matrix.scala}}/maven-artifacts artifacts/ cp -r {{basedir}}/spark-3/target/{{projectName}}*.zip artifacts/ cp -r {{basedir}}/spark-3/target/{{projectName}}*.jar artifacts/ signing: active: ALWAYS mode: COMMAND command: homeDir: '~/.gnupg' deploy: maven: active: ALWAYS mavenCentral: artifacts: active: ALWAYS url: https://central.sonatype.com/api/v1/publisher applyMavenCentralRules: true namespace: org.neo4j verifyPom: false stagingRepositories: - ./artifacts/maven-artifacts announce: slack: channels: - '#release' - '#team-spark' message: ':tada: @release-pm {{projectNameCapitalized}} {{projectVersion}} has been released! 
{{releaseNotesUrl}}' ================================================ FILE: maven-release.sh ================================================ #!/bin/bash set -eEuxo pipefail if [[ $# -lt 2 ]] ; then echo "Usage ./maven-release.sh []" exit 1 fi exit_script() { echo "Process terminated cleaning up resources" mv -f pom.xml.bak pom.xml mv -f common/pom.xml.bak common/pom.xml mv -f test-support/pom.xml.bak test-support/pom.xml mv -f spark-3/pom.xml.bak spark-3/pom.xml trap - SIGINT SIGTERM # clear the trap kill -- -$$ || true # Sends SIGTERM to child/sub processes } mvn_evaluate() { local expression expression="${1}" ./mvnw -B help:evaluate -Dexpression="${expression}" --quiet -DforceStdout } trap exit_script SIGINT SIGTERM GOAL=$1 SCALA_VERSION=$2 SPARK_VERSION=3 if [[ $# -eq 3 ]] ; then ALT_DEPLOYMENT_REPOSITORY="-DaltDeploymentRepository=$3" else ALT_DEPLOYMENT_REPOSITORY="" fi case $(sed --help 2>&1) in *GNU*) sed_i () { sed -i "$@"; };; *) sed_i () { sed -i '' "$@"; };; esac PROJECT_VERSION=$(mvn_evaluate "project.version") SPARK_PACKAGES_VERSION="${PROJECT_VERSION}-s_$SCALA_VERSION" # backup files cp pom.xml pom.xml.bak cp common/pom.xml common/pom.xml.bak cp test-support/pom.xml test-support/pom.xml.bak cp spark-3/pom.xml spark-3/pom.xml.bak ./mvnw -B versions:set -DnewVersion=${PROJECT_VERSION}_for_spark_${SPARK_VERSION} -DgenerateBackupPoms=false # replace pom files with target scala version sed_i "s/neo4j-connector-apache-spark_parent<\/artifactId>/neo4j-connector-apache-spark_${SCALA_VERSION}_parent<\/artifactId>/" pom.xml sed_i "s/neo4j-connector-apache-spark_parent<\/artifactId>/neo4j-connector-apache-spark_${SCALA_VERSION}_parent<\/artifactId>/" "test-support/pom.xml" sed_i "s/neo4j-connector-apache-spark_test-support<\/artifactId>/neo4j-connector-apache-spark_${SCALA_VERSION}_test-support<\/artifactId>/" "test-support/pom.xml" sed_i "s/neo4j-connector-apache-spark_common<\/artifactId>/neo4j-connector-apache-spark_${SCALA_VERSION}_common<\/artifactId>/" 
"common/pom.xml" sed_i "s/neo4j-connector-apache-spark_parent<\/artifactId>/neo4j-connector-apache-spark_${SCALA_VERSION}_parent<\/artifactId>/" "common/pom.xml" sed_i "s/neo4j-connector-apache-spark_test-support<\/artifactId>/neo4j-connector-apache-spark_${SCALA_VERSION}_test-support<\/artifactId>/" "common/pom.xml" sed_i "s/neo4j-connector-apache-spark<\/artifactId>/neo4j-connector-apache-spark_${SCALA_VERSION}<\/artifactId>/" "spark-3/pom.xml" sed_i "s/neo4j-connector-apache-spark_parent<\/artifactId>/neo4j-connector-apache-spark_${SCALA_VERSION}_parent<\/artifactId>/" "spark-3/pom.xml" sed_i "s/neo4j-connector-apache-spark_common<\/artifactId>/neo4j-connector-apache-spark_${SCALA_VERSION}_common<\/artifactId>/" "spark-3/pom.xml" sed_i "s/neo4j-connector-apache-spark_test-support<\/artifactId>/neo4j-connector-apache-spark_${SCALA_VERSION}_test-support<\/artifactId>/" "spark-3/pom.xml" sed_i "s//${SPARK_PACKAGES_VERSION}<\/spark-packages.version>/" "spark-3/pom.xml" # build ./mvnw -B clean "${GOAL}" -Dscala-"${SCALA_VERSION}" -DskipTests ${ALT_DEPLOYMENT_REPOSITORY} if [ ! ${CI:-false} = true ]; then exit_script fi ================================================ FILE: mvnw ================================================ #!/bin/sh # ---------------------------------------------------------------------------- # Licensed to the Apache Software Foundation (ASF) under one # or more contributor license agreements. See the NOTICE file # distributed with this work for additional information # regarding copyright ownership. The ASF licenses this file # to you under the Apache License, Version 2.0 (the # "License"); you may not use this file except in compliance # with the License. 
You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, # software distributed under the License is distributed on an # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. # ---------------------------------------------------------------------------- # ---------------------------------------------------------------------------- # Apache Maven Wrapper startup batch script, version 3.3.2 # # Optional ENV vars # ----------------- # JAVA_HOME - location of a JDK home dir, required when download maven via java source # MVNW_REPOURL - repo url base for downloading maven distribution # MVNW_USERNAME/MVNW_PASSWORD - user and password for downloading maven # MVNW_VERBOSE - true: enable verbose log; debug: trace the mvnw script; others: silence the output # ---------------------------------------------------------------------------- set -euf [ "${MVNW_VERBOSE-}" != debug ] || set -x # OS specific support. native_path() { printf %s\\n "$1"; } case "$(uname)" in CYGWIN* | MINGW*) [ -z "${JAVA_HOME-}" ] || JAVA_HOME="$(cygpath --unix "$JAVA_HOME")" native_path() { cygpath --path --windows "$1"; } ;; esac # set JAVACMD and JAVACCMD set_java_home() { # For Cygwin and MinGW, ensure paths are in Unix format before anything is touched if [ -n "${JAVA_HOME-}" ]; then if [ -x "$JAVA_HOME/jre/sh/java" ]; then # IBM's JDK on AIX uses strange locations for the executables JAVACMD="$JAVA_HOME/jre/sh/java" JAVACCMD="$JAVA_HOME/jre/sh/javac" else JAVACMD="$JAVA_HOME/bin/java" JAVACCMD="$JAVA_HOME/bin/javac" if [ ! -x "$JAVACMD" ] || [ ! -x "$JAVACCMD" ]; then echo "The JAVA_HOME environment variable is not defined correctly, so mvnw cannot run." 
>&2 echo "JAVA_HOME is set to \"$JAVA_HOME\", but \"\$JAVA_HOME/bin/java\" or \"\$JAVA_HOME/bin/javac\" does not exist." >&2 return 1 fi fi else JAVACMD="$( 'set' +e 'unset' -f command 2>/dev/null 'command' -v java )" || : JAVACCMD="$( 'set' +e 'unset' -f command 2>/dev/null 'command' -v javac )" || : if [ ! -x "${JAVACMD-}" ] || [ ! -x "${JAVACCMD-}" ]; then echo "The java/javac command does not exist in PATH nor is JAVA_HOME set, so mvnw cannot run." >&2 return 1 fi fi } # hash string like Java String::hashCode hash_string() { str="${1:-}" h=0 while [ -n "$str" ]; do char="${str%"${str#?}"}" h=$(((h * 31 + $(LC_CTYPE=C printf %d "'$char")) % 4294967296)) str="${str#?}" done printf %x\\n $h } verbose() { :; } [ "${MVNW_VERBOSE-}" != true ] || verbose() { printf %s\\n "${1-}"; } die() { printf %s\\n "$1" >&2 exit 1 } trim() { # MWRAPPER-139: # Trims trailing and leading whitespace, carriage returns, tabs, and linefeeds. # Needed for removing poorly interpreted newline sequences when running in more # exotic environments such as mingw bash on Windows. 
printf "%s" "${1}" | tr -d '[:space:]' } # parse distributionUrl and optional distributionSha256Sum, requires .mvn/wrapper/maven-wrapper.properties while IFS="=" read -r key value; do case "${key-}" in distributionUrl) distributionUrl=$(trim "${value-}") ;; distributionSha256Sum) distributionSha256Sum=$(trim "${value-}") ;; esac done <"${0%/*}/.mvn/wrapper/maven-wrapper.properties" [ -n "${distributionUrl-}" ] || die "cannot read distributionUrl property in ${0%/*}/.mvn/wrapper/maven-wrapper.properties" case "${distributionUrl##*/}" in maven-mvnd-*bin.*) MVN_CMD=mvnd.sh _MVNW_REPO_PATTERN=/maven/mvnd/ case "${PROCESSOR_ARCHITECTURE-}${PROCESSOR_ARCHITEW6432-}:$(uname -a)" in *AMD64:CYGWIN* | *AMD64:MINGW*) distributionPlatform=windows-amd64 ;; :Darwin*x86_64) distributionPlatform=darwin-amd64 ;; :Darwin*arm64) distributionPlatform=darwin-aarch64 ;; :Linux*x86_64*) distributionPlatform=linux-amd64 ;; *) echo "Cannot detect native platform for mvnd on $(uname)-$(uname -m), use pure java version" >&2 distributionPlatform=linux-amd64 ;; esac distributionUrl="${distributionUrl%-bin.*}-$distributionPlatform.zip" ;; maven-mvnd-*) MVN_CMD=mvnd.sh _MVNW_REPO_PATTERN=/maven/mvnd/ ;; *) MVN_CMD="mvn${0##*/mvnw}" _MVNW_REPO_PATTERN=/org/apache/maven/ ;; esac # apply MVNW_REPOURL and calculate MAVEN_HOME # maven home pattern: ~/.m2/wrapper/dists/{apache-maven-,maven-mvnd--}/ [ -z "${MVNW_REPOURL-}" ] || distributionUrl="$MVNW_REPOURL$_MVNW_REPO_PATTERN${distributionUrl#*"$_MVNW_REPO_PATTERN"}" distributionUrlName="${distributionUrl##*/}" distributionUrlNameMain="${distributionUrlName%.*}" distributionUrlNameMain="${distributionUrlNameMain%-bin}" MAVEN_USER_HOME="${MAVEN_USER_HOME:-${HOME}/.m2}" MAVEN_HOME="${MAVEN_USER_HOME}/wrapper/dists/${distributionUrlNameMain-}/$(hash_string "$distributionUrl")" exec_maven() { unset MVNW_VERBOSE MVNW_USERNAME MVNW_PASSWORD MVNW_REPOURL || : exec "$MAVEN_HOME/bin/$MVN_CMD" "$@" || die "cannot exec $MAVEN_HOME/bin/$MVN_CMD" } if [ -d 
"$MAVEN_HOME" ]; then verbose "found existing MAVEN_HOME at $MAVEN_HOME" exec_maven "$@" fi case "${distributionUrl-}" in *?-bin.zip | *?maven-mvnd-?*-?*.zip) ;; *) die "distributionUrl is not valid, must match *-bin.zip or maven-mvnd-*.zip, but found '${distributionUrl-}'" ;; esac # prepare tmp dir if TMP_DOWNLOAD_DIR="$(mktemp -d)" && [ -d "$TMP_DOWNLOAD_DIR" ]; then clean() { rm -rf -- "$TMP_DOWNLOAD_DIR"; } trap clean HUP INT TERM EXIT else die "cannot create temp dir" fi mkdir -p -- "${MAVEN_HOME%/*}" # Download and Install Apache Maven verbose "Couldn't find MAVEN_HOME, downloading and installing it ..." verbose "Downloading from: $distributionUrl" verbose "Downloading to: $TMP_DOWNLOAD_DIR/$distributionUrlName" # select .zip or .tar.gz if ! command -v unzip >/dev/null; then distributionUrl="${distributionUrl%.zip}.tar.gz" distributionUrlName="${distributionUrl##*/}" fi # verbose opt __MVNW_QUIET_WGET=--quiet __MVNW_QUIET_CURL=--silent __MVNW_QUIET_UNZIP=-q __MVNW_QUIET_TAR='' [ "${MVNW_VERBOSE-}" != true ] || __MVNW_QUIET_WGET='' __MVNW_QUIET_CURL='' __MVNW_QUIET_UNZIP='' __MVNW_QUIET_TAR=v # normalize http auth case "${MVNW_PASSWORD:+has-password}" in '') MVNW_USERNAME='' MVNW_PASSWORD='' ;; has-password) [ -n "${MVNW_USERNAME-}" ] || MVNW_USERNAME='' MVNW_PASSWORD='' ;; esac if [ -z "${MVNW_USERNAME-}" ] && command -v wget >/dev/null; then verbose "Found wget ... using wget" wget ${__MVNW_QUIET_WGET:+"$__MVNW_QUIET_WGET"} "$distributionUrl" -O "$TMP_DOWNLOAD_DIR/$distributionUrlName" || die "wget: Failed to fetch $distributionUrl" elif [ -z "${MVNW_USERNAME-}" ] && command -v curl >/dev/null; then verbose "Found curl ... 
using curl" curl ${__MVNW_QUIET_CURL:+"$__MVNW_QUIET_CURL"} -f -L -o "$TMP_DOWNLOAD_DIR/$distributionUrlName" "$distributionUrl" || die "curl: Failed to fetch $distributionUrl" elif set_java_home; then verbose "Falling back to use Java to download" javaSource="$TMP_DOWNLOAD_DIR/Downloader.java" targetZip="$TMP_DOWNLOAD_DIR/$distributionUrlName" cat >"$javaSource" <<-END public class Downloader extends java.net.Authenticator { protected java.net.PasswordAuthentication getPasswordAuthentication() { return new java.net.PasswordAuthentication( System.getenv( "MVNW_USERNAME" ), System.getenv( "MVNW_PASSWORD" ).toCharArray() ); } public static void main( String[] args ) throws Exception { setDefault( new Downloader() ); java.nio.file.Files.copy( java.net.URI.create( args[0] ).toURL().openStream(), java.nio.file.Paths.get( args[1] ).toAbsolutePath().normalize() ); } } END # For Cygwin/MinGW, switch paths to Windows format before running javac and java verbose " - Compiling Downloader.java ..." "$(native_path "$JAVACCMD")" "$(native_path "$javaSource")" || die "Failed to compile Downloader.java" verbose " - Running Downloader.java ..." "$(native_path "$JAVACMD")" -cp "$(native_path "$TMP_DOWNLOAD_DIR")" Downloader "$distributionUrl" "$(native_path "$targetZip")" fi # If specified, validate the SHA-256 sum of the Maven distribution zip file if [ -n "${distributionSha256Sum-}" ]; then distributionSha256Result=false if [ "$MVN_CMD" = mvnd.sh ]; then echo "Checksum validation is not supported for maven-mvnd." >&2 echo "Please disable validation by removing 'distributionSha256Sum' from your maven-wrapper.properties." 
>&2 exit 1 elif command -v sha256sum >/dev/null; then if echo "$distributionSha256Sum $TMP_DOWNLOAD_DIR/$distributionUrlName" | sha256sum -c >/dev/null 2>&1; then distributionSha256Result=true fi elif command -v shasum >/dev/null; then if echo "$distributionSha256Sum $TMP_DOWNLOAD_DIR/$distributionUrlName" | shasum -a 256 -c >/dev/null 2>&1; then distributionSha256Result=true fi else echo "Checksum validation was requested but neither 'sha256sum' or 'shasum' are available." >&2 echo "Please install either command, or disable validation by removing 'distributionSha256Sum' from your maven-wrapper.properties." >&2 exit 1 fi if [ $distributionSha256Result = false ]; then echo "Error: Failed to validate Maven distribution SHA-256, your Maven distribution might be compromised." >&2 echo "If you updated your Maven version, you need to update the specified distributionSha256Sum property." >&2 exit 1 fi fi # unzip and move if command -v unzip >/dev/null; then unzip ${__MVNW_QUIET_UNZIP:+"$__MVNW_QUIET_UNZIP"} "$TMP_DOWNLOAD_DIR/$distributionUrlName" -d "$TMP_DOWNLOAD_DIR" || die "failed to unzip" else tar xzf${__MVNW_QUIET_TAR:+"$__MVNW_QUIET_TAR"} "$TMP_DOWNLOAD_DIR/$distributionUrlName" -C "$TMP_DOWNLOAD_DIR" || die "failed to untar" fi printf %s\\n "$distributionUrl" >"$TMP_DOWNLOAD_DIR/$distributionUrlNameMain/mvnw.url" mv -- "$TMP_DOWNLOAD_DIR/$distributionUrlNameMain" "$MAVEN_HOME" || [ -d "$MAVEN_HOME" ] || die "fail to move MAVEN_HOME" clean || : exec_maven "$@" ================================================ FILE: mvnw.cmd ================================================ <# : batch portion @REM ---------------------------------------------------------------------------- @REM Licensed to the Apache Software Foundation (ASF) under one @REM or more contributor license agreements. See the NOTICE file @REM distributed with this work for additional information @REM regarding copyright ownership. 
The ASF licenses this file @REM to you under the Apache License, Version 2.0 (the @REM "License"); you may not use this file except in compliance @REM with the License. You may obtain a copy of the License at @REM @REM http://www.apache.org/licenses/LICENSE-2.0 @REM @REM Unless required by applicable law or agreed to in writing, @REM software distributed under the License is distributed on an @REM "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY @REM KIND, either express or implied. See the License for the @REM specific language governing permissions and limitations @REM under the License. @REM ---------------------------------------------------------------------------- @REM ---------------------------------------------------------------------------- @REM Apache Maven Wrapper startup batch script, version 3.3.2 @REM @REM Optional ENV vars @REM MVNW_REPOURL - repo url base for downloading maven distribution @REM MVNW_USERNAME/MVNW_PASSWORD - user and password for downloading maven @REM MVNW_VERBOSE - true: enable verbose log; others: silence the output @REM ---------------------------------------------------------------------------- @IF "%__MVNW_ARG0_NAME__%"=="" (SET __MVNW_ARG0_NAME__=%~nx0) @SET __MVNW_CMD__= @SET __MVNW_ERROR__= @SET __MVNW_PSMODULEP_SAVE=%PSModulePath% @SET PSModulePath= @FOR /F "usebackq tokens=1* delims==" %%A IN (`powershell -noprofile "& {$scriptDir='%~dp0'; $script='%__MVNW_ARG0_NAME__%'; icm -ScriptBlock ([Scriptblock]::Create((Get-Content -Raw '%~f0'))) -NoNewScope}"`) DO @( IF "%%A"=="MVN_CMD" (set __MVNW_CMD__=%%B) ELSE IF "%%B"=="" (echo %%A) ELSE (echo %%A=%%B) ) @SET PSModulePath=%__MVNW_PSMODULEP_SAVE% @SET __MVNW_PSMODULEP_SAVE= @SET __MVNW_ARG0_NAME__= @SET MVNW_USERNAME= @SET MVNW_PASSWORD= @IF NOT "%__MVNW_CMD__%"=="" (%__MVNW_CMD__% %*) @echo Cannot start maven from wrapper >&2 && exit /b 1 @GOTO :EOF : end batch / begin powershell #> $ErrorActionPreference = "Stop" if ($env:MVNW_VERBOSE -eq "true") { $VerbosePreference 
= "Continue" } # calculate distributionUrl, requires .mvn/wrapper/maven-wrapper.properties $distributionUrl = (Get-Content -Raw "$scriptDir/.mvn/wrapper/maven-wrapper.properties" | ConvertFrom-StringData).distributionUrl if (!$distributionUrl) { Write-Error "cannot read distributionUrl property in $scriptDir/.mvn/wrapper/maven-wrapper.properties" } switch -wildcard -casesensitive ( $($distributionUrl -replace '^.*/','') ) { "maven-mvnd-*" { $USE_MVND = $true $distributionUrl = $distributionUrl -replace '-bin\.[^.]*$',"-windows-amd64.zip" $MVN_CMD = "mvnd.cmd" break } default { $USE_MVND = $false $MVN_CMD = $script -replace '^mvnw','mvn' break } } # apply MVNW_REPOURL and calculate MAVEN_HOME # maven home pattern: ~/.m2/wrapper/dists/{apache-maven-,maven-mvnd--}/ if ($env:MVNW_REPOURL) { $MVNW_REPO_PATTERN = if ($USE_MVND) { "/org/apache/maven/" } else { "/maven/mvnd/" } $distributionUrl = "$env:MVNW_REPOURL$MVNW_REPO_PATTERN$($distributionUrl -replace '^.*'+$MVNW_REPO_PATTERN,'')" } $distributionUrlName = $distributionUrl -replace '^.*/','' $distributionUrlNameMain = $distributionUrlName -replace '\.[^.]*$','' -replace '-bin$','' $MAVEN_HOME_PARENT = "$HOME/.m2/wrapper/dists/$distributionUrlNameMain" if ($env:MAVEN_USER_HOME) { $MAVEN_HOME_PARENT = "$env:MAVEN_USER_HOME/wrapper/dists/$distributionUrlNameMain" } $MAVEN_HOME_NAME = ([System.Security.Cryptography.MD5]::Create().ComputeHash([byte[]][char[]]$distributionUrl) | ForEach-Object {$_.ToString("x2")}) -join '' $MAVEN_HOME = "$MAVEN_HOME_PARENT/$MAVEN_HOME_NAME" if (Test-Path -Path "$MAVEN_HOME" -PathType Container) { Write-Verbose "found existing MAVEN_HOME at $MAVEN_HOME" Write-Output "MVN_CMD=$MAVEN_HOME/bin/$MVN_CMD" exit $? } if (! 
$distributionUrlNameMain -or ($distributionUrlName -eq $distributionUrlNameMain)) { Write-Error "distributionUrl is not valid, must end with *-bin.zip, but found $distributionUrl" } # prepare tmp dir $TMP_DOWNLOAD_DIR_HOLDER = New-TemporaryFile $TMP_DOWNLOAD_DIR = New-Item -Itemtype Directory -Path "$TMP_DOWNLOAD_DIR_HOLDER.dir" $TMP_DOWNLOAD_DIR_HOLDER.Delete() | Out-Null trap { if ($TMP_DOWNLOAD_DIR.Exists) { try { Remove-Item $TMP_DOWNLOAD_DIR -Recurse -Force | Out-Null } catch { Write-Warning "Cannot remove $TMP_DOWNLOAD_DIR" } } } New-Item -Itemtype Directory -Path "$MAVEN_HOME_PARENT" -Force | Out-Null # Download and Install Apache Maven Write-Verbose "Couldn't find MAVEN_HOME, downloading and installing it ..." Write-Verbose "Downloading from: $distributionUrl" Write-Verbose "Downloading to: $TMP_DOWNLOAD_DIR/$distributionUrlName" $webclient = New-Object System.Net.WebClient if ($env:MVNW_USERNAME -and $env:MVNW_PASSWORD) { $webclient.Credentials = New-Object System.Net.NetworkCredential($env:MVNW_USERNAME, $env:MVNW_PASSWORD) } [Net.ServicePointManager]::SecurityProtocol = [Net.SecurityProtocolType]::Tls12 $webclient.DownloadFile($distributionUrl, "$TMP_DOWNLOAD_DIR/$distributionUrlName") | Out-Null # If specified, validate the SHA-256 sum of the Maven distribution zip file $distributionSha256Sum = (Get-Content -Raw "$scriptDir/.mvn/wrapper/maven-wrapper.properties" | ConvertFrom-StringData).distributionSha256Sum if ($distributionSha256Sum) { if ($USE_MVND) { Write-Error "Checksum validation is not supported for maven-mvnd. `nPlease disable validation by removing 'distributionSha256Sum' from your maven-wrapper.properties." } Import-Module $PSHOME\Modules\Microsoft.PowerShell.Utility -Function Get-FileHash if ((Get-FileHash "$TMP_DOWNLOAD_DIR/$distributionUrlName" -Algorithm SHA256).Hash.ToLower() -ne $distributionSha256Sum) { Write-Error "Error: Failed to validate Maven distribution SHA-256, your Maven distribution might be compromised. 
If you updated your Maven version, you need to update the specified distributionSha256Sum property." } } # unzip and move Expand-Archive "$TMP_DOWNLOAD_DIR/$distributionUrlName" -DestinationPath "$TMP_DOWNLOAD_DIR" | Out-Null Rename-Item -Path "$TMP_DOWNLOAD_DIR/$distributionUrlNameMain" -NewName $MAVEN_HOME_NAME | Out-Null try { Move-Item -Path "$TMP_DOWNLOAD_DIR/$MAVEN_HOME_NAME" -Destination $MAVEN_HOME_PARENT | Out-Null } catch { if (! (Test-Path -Path "$MAVEN_HOME" -PathType Container)) { Write-Error "fail to move MAVEN_HOME" } } finally { try { Remove-Item $TMP_DOWNLOAD_DIR -Recurse -Force | Out-Null } catch { Write-Warning "Cannot remove $TMP_DOWNLOAD_DIR" } } Write-Output "MVN_CMD=$MAVEN_HOME/bin/$MVN_CMD" ================================================ FILE: package.json ================================================ { "devDependencies": { "@commitlint/cli": "^20.4.2", "@commitlint/config-conventional": "^20.4.2", "@commitlint/lint": "^20.4.2", "@commitlint/load": "^20.4.0", "danger": "^13.0.7", "husky": "^9.1.7" }, "scripts": { "prepare": "husky" } } ================================================ FILE: pom.xml ================================================ 4.0.0 org.neo4j neo4j-connector-apache-spark_parent 5.4.3-SNAPSHOT pom neo4j-connector-apache-spark_parent Neo4j Connector for Apache Spark using the binary Bolt Driver https://github.com/neo4j/neo4j-spark-connector Neo4j, Inc. 
https://neo4j.com/ Apache License, Version 2.0 https://www.apache.org/licenses/LICENSE-2.0.txt manual team-connectors Connectors Team Neo4j https://neo4j.com common test-support spark-3 https://github.com/neo4j/neo4j-spark-connector 2026.02.5 1.3.3 1.28.0 3.20.0 1.0.0-rc2 2022.11.0 4.4.22 1.8 3.4.3.Final 4.13.2 4.6 1.7.12 /license/neo4j_apache_v2/notice.txt 3.5.0 3.15.0 3.10.0 3.1.4 3.6.2 3.5.5 3.1.4 3.5.0 3.5.0 3.21.0 3.4.0 3.5.5 false 4.4.20 4.1.132.Final UTF-8 4.9.10 2.12 2.12.20 2.0.17 4.0.0 3.5.8 3.4.0 2.0.4 io.netty netty-bom ${netty-bom.version} pom import org.testcontainers testcontainers-bom ${testcontainers.version} pom import junit junit ${junit.version} org.apache.commons commons-compress ${commons-compress.version} org.apache.commons commons-lang3 ${commons-lang3.version} org.apache.spark spark-sql_${scala.binary.version} ${spark.version} org.hamcrest hamcrest 3.0 org.jboss.logging jboss-logging ${jboss-logging.version} org.neo4j caniuse-core ${caniuse.version} org.neo4j caniuse-neo4j-detection ${caniuse.version} org.neo4j neo4j-cypher-dsl ${cypherdsl.version} org.neo4j.connectors commons-authn-keycloak ${connectors-commons.version} org.neo4j.connectors commons-authn-provided ${connectors-commons.version} org.neo4j.connectors commons-authn-spi ${connectors-commons.version} org.neo4j.connectors commons-reauth-driver ${connectors-commons.version} org.neo4j.driver neo4j-java-driver-slim ${driver.version} org.scala-lang scala-library ${scala.version} org.scala-lang scala-reflect ${scala.version} org.scalatest scalatest_${scala.binary.version} 3.2.20 org.scalatestplus junit-4-13_${scala.binary.version} 3.2.20.0 org.slf4j slf4j-api ${slf4j-api.version} org.apache.spark spark-core_${scala.binary.version} ${spark.version} provided org.apache.xbean xbean-asm6-shaded org.codehaus.mojo versions-maven-plugin 2.21.0 false org.apache.maven.plugins maven-clean-plugin ${maven-clean-plugin.version} org.apache.maven.plugins maven-deploy-plugin 
${maven-deploy-plugin.version} org.apache.maven.plugins maven-install-plugin ${maven-install-plugin.version} org.apache.maven.plugins maven-site-plugin ${maven-site-plugin.version} org.apache.maven.plugins maven-jar-plugin ${maven-jar-plugin.version} true org.apache.maven.plugins maven-resources-plugin ${maven-resources-plugin.version} ${project.build.sourceEncoding} org.neo4j.build.plugins licensing-maven-plugin ${licensing-maven-plugin.version} true true true ${licensing.prepend.text} ^((org.neo4j){1}|(org.neo4j.connectors){1}|(org.neo4j.connectors.kafka){1}|(org.neo4j.driver){1})$ compile,runtime org.neo4j.build resources ${build-resources.version} list-all-licenses check compile licensing/licensing-requirements-base.xml ${project.artifactId}-${project.version}-NOTICE.txt ${project.build.directory}/../NOTICE.txt /licensing/list-prefix.txt ${project.artifactId}-${project.version}-LICENSES.txt ${project.build.directory}/../LICENSES.txt org.apache.maven.plugins maven-source-plugin ${maven-source-plugin.version} attach-sources jar org.apache.maven.plugins maven-enforcer-plugin ${maven-enforcer-plugin.version} enforce enforce validate [3.6.3,) 8 org.apache.maven.plugins maven-compiler-plugin ${maven-compiler-plugin.version} ${java.version} ${java.version} org.apache.maven.plugins maven-failsafe-plugin ${maven-failsafe-plugin.version} **/*IT.* **/*TSE.* 1 false ${surefire.jvm.args} false integration-test verify integration-test org.apache.maven.plugins maven-dependency-plugin ${maven-dependency-plugin.version} com.mycila license-maven-plugin ${license-maven-plugin.version} true
/license/neo4j_apache_v2/header.txt
src/**/*.scala src/**/*.java
SLASHSTAR_STYLE
org.neo4j.build resources ${build-resources.version} check-licenses check compile
org.apache.maven.plugins maven-surefire-plugin ${maven-surefire-plugin.version} ${surefire.jvm.args} net.alchim31.maven scala-maven-plugin ${scala-maven-plugin.version} ${scala.version} ${scala.binary.version} -target:jvm-1.8 -Xms64m -Xmx1024m scala-compile add-source compile testCompile doc-jar process-resources scala-test-compile testCompile test-compile maven-assembly-plugin 3.8.0 jar-with-dependencies ${project.artifactId}-${project.version} false single package com.github.ekryd.sortpom sortpom-maven-plugin ${sortpom-maven-plugin.version} ${project.build.sourceEncoding} false schemaLocation 4 true scope,groupId,artifactId false false verify validate STOP com.diffplug.spotless spotless-maven-plugin ${spotless-maven-plugin.version} src/main/scala/**/*.scala src/test/scala/**/*.scala 3.8.4-RC3 scalafmt/scalafmt.conf 2.12 org.neo4j.build resources ${build-resources.version} check compile
org.apache.maven.plugins maven-enforcer-plugin
scala-2.13 scala-2.13 2.13 2.13.16 has-sources false src org.neo4j.build.plugins licensing-maven-plugin com.mycila license-maven-plugin org.apache.maven.plugins maven-source-plugin enable-jdk11-plugins [11,) -XX:+IgnoreUnrecognizedVMOptions --add-opens=java.base/java.lang=ALL-UNNAMED --add-opens=java.base/java.lang.invoke=ALL-UNNAMED --add-opens=java.base/java.lang.reflect=ALL-UNNAMED --add-opens=java.base/java.io=ALL-UNNAMED --add-opens=java.base/java.net=ALL-UNNAMED --add-opens=java.base/java.nio=ALL-UNNAMED --add-opens=java.base/java.util=ALL-UNNAMED --add-opens=java.base/java.util.concurrent=ALL-UNNAMED --add-opens=java.base/java.util.concurrent.atomic=ALL-UNNAMED --add-opens=java.base/sun.nio.ch=ALL-UNNAMED --add-opens=java.base/sun.nio.cs=ALL-UNNAMED --add-opens=java.base/sun.security.action=ALL-UNNAMED --add-opens=java.base/sun.util.calendar=ALL-UNNAMED --add-opens=java.security.jgss/sun.security.krb5=ALL-UNNAMED -Djdk.reflect.useDirectMethodHandle=false com.github.ekryd.sortpom sortpom-maven-plugin enable-jdk17-plugins [17,) com.diffplug.spotless spotless-maven-plugin
================================================ FILE: scripts/python/requirements.txt ================================================ pyspark==3.5.5 testcontainers[neo4j] six tzlocal==2.1 ================================================ FILE: scripts/python/test_spark.py ================================================ #!/usr/bin/env python3 import unittest import sys import datetime from tzlocal import get_localzone from testcontainers.neo4j import Neo4jContainer from pyspark.sql import SparkSession from neo4j import Driver class SparkTest(unittest.TestCase): neo4j_driver: Driver = None neo4j_container: Neo4jContainer = None spark: SparkSession = None def tearDown(self): with self.neo4j_driver.session(database="system") as session: session.run("CREATE OR REPLACE DATABASE neo4j WAIT 30 seconds").consume() def init_test(self, query, parameters=None): with self.neo4j_driver.session() as session: session.run(query, parameters).consume() return ( self.spark.read.format("org.neo4j.spark.DataSource") .option("url", self.neo4j_container.get_connection_url()) .option("authentication.type", "basic") .option("authentication.basic.username", "neo4j") .option("authentication.basic.password", "password") .option("labels", "Person") .load() ) def test_string(self): name = "Foobar" df = self.init_test("CREATE (p:Person {name: '" + name + "'})") assert name == df.select("name").collect()[0].name def test_int(self): age = 32 df = self.init_test("CREATE (p:Person {age: " + str(age) + "})") assert age == df.select("age").collect()[0].age def test_double(self): score = 32.3 df = self.init_test("CREATE (p:Person {score: " + str(score) + "})") assert score == df.select("score").collect()[0].score def test_boolean(self): df = self.init_test("CREATE (p:Person {boolean: true})") assert True == df.select("boolean").collect()[0].boolean def test_time(self): time = datetime.time(12, 23, 0, 0, get_localzone()) df = self.init_test( "CREATE (p:Person {myTime: time({hour:12, minute: 23, second: 
0})})" ) timeResult = df.select("myTime").collect()[0].myTime assert "offset-time" == timeResult.type # .replace used in case of UTC timezone because of https://stackoverflow.com/a/42777551/1409772 assert str(time).replace("+00:00", "Z") == timeResult.value.split("+")[0] def test_datetime(self): dtString = "2015-06-24T12:50:35" df = self.init_test( "CREATE (p:Person {datetime: datetime('" + dtString + "')})" ) dt = datetime.datetime(2015, 6, 24, 12, 50, 35, 0) dtResult = df.select("datetime").collect()[0].datetime assert dt == dtResult def test_date(self): df = self.init_test("CREATE (p:Person {born: date('2009-10-10')})") dt = datetime.date(2009, 10, 10) dtResult = df.select("born").collect()[0].born assert dt == dtResult def test_point(self): df = self.init_test("CREATE (p:Person {location: point({x: 12.12, y: 13.13})})") pointResult = df.select("location").collect()[0].location assert "point-2d" == pointResult[0] assert 7203 == pointResult[1] assert 12.12 == pointResult[2] assert 13.13 == pointResult[3] def test_point3d(self): df = self.init_test( "CREATE (p:Person {location: point({x: 12.12, y: 13.13, z: 1})})" ) pointResult = df.select("location").collect()[0].location assert "point-3d" == pointResult[0] assert 9157 == pointResult[1] assert 12.12 == pointResult[2] assert 13.13 == pointResult[3] assert 1.0 == pointResult[4] def test_geopoint(self): df = self.init_test( "CREATE (p:Person {location: point({longitude: 12.12, latitude: 13.13})})" ) pointResult = df.select("location").collect()[0].location assert "point-2d" == pointResult[0] assert 4326 == pointResult[1] assert 12.12 == pointResult[2] assert 13.13 == pointResult[3] def test_duration(self): df = self.init_test( "CREATE (p:Person {range: duration({days: 14, hours:16, minutes: 12})})" ) durationResult = df.select("range").collect()[0].range assert "duration" == durationResult[0] assert 0 == durationResult[1] assert 14 == durationResult[2] assert 58320 == durationResult[3] assert 0 == durationResult[4] 
def test_binary(self): byte_array = b"binaries are byte arrays" df = self.init_test("CREATE (p:Person {bin: $bytes})", {"bytes": byte_array}) assert byte_array == df.select("bin").collect()[0].bin def test_string_array(self): df = self.init_test("CREATE (p:Person {names: ['John', 'Doe']})") result = df.select("names").collect()[0].names assert "John" == result[0] assert "Doe" == result[1] def test_int_array(self): df = self.init_test("CREATE (p:Person {ages: [24, 56]})") result = df.select("ages").collect()[0].ages assert 24 == result[0] assert 56 == result[1] def test_double_array(self): df = self.init_test("CREATE (p:Person {scores: [24.11, 56.11]})") result = df.select("scores").collect()[0].scores assert 24.11 == result[0] assert 56.11 == result[1] def test_boolean_array(self): df = self.init_test("CREATE (p:Person {field: [true, false]})") result = df.select("field").collect()[0].field assert True == result[0] assert False == result[1] def test_time_array(self): df = self.init_test( "CREATE (p:Person {result: [time({hour:11, minute: 23, second: 0}), time({hour:12, minute: 23, second: 0})]})" ) timeResult = df.select("result").collect()[0].result # .replace used in case of UTC timezone because of https://stackoverflow.com/a/42777551/1409772 assert "offset-time" == timeResult[0].type assert ( str(datetime.time(11, 23, 0, 0, get_localzone())).replace("+00:00", "Z") == timeResult[0].value.split("+")[0] ) # .replace used in case of UTC timezone because of https://stackoverflow.com/a/42777551/1409772 assert "offset-time" == timeResult[1].type assert ( str(datetime.time(12, 23, 0, 0, get_localzone())).replace("+00:00", "Z") == timeResult[1].value.split("+")[0] ) def test_datetime_array(self): df = self.init_test( "CREATE (p:Person {result: [datetime('2007-12-03T10:15:30'), datetime('2008-12-03T10:15:30')]})" ) dt1 = datetime.datetime(2007, 12, 3, 10, 15, 30, 0) dt2 = datetime.datetime(2008, 12, 3, 10, 15, 30, 0) dtResult = df.select("result").collect()[0].result 
assert dt1 == dtResult[0] assert dt2 == dtResult[1] def test_date_array(self): df = self.init_test( "CREATE (p:Person {result: [date('2009-10-10'), date('2008-10-10')]})" ) dt1 = datetime.date(2009, 10, 10) dt2 = datetime.date(2008, 10, 10) dtResult = df.select("result").collect()[0].result assert dt1 == dtResult[0] assert dt2 == dtResult[1] def test_point_array(self): df = self.init_test( "CREATE (p:Person {location: [point({x: 12.12, y: 13.13}), point({x: 13.13, y: 14.14})]})" ) pointResult = df.select("location").collect()[0].location assert "point-2d" == pointResult[0][0] assert 7203 == pointResult[0][1] assert 12.12 == pointResult[0][2] assert 13.13 == pointResult[0][3] assert "point-2d" == pointResult[1][0] assert 7203 == pointResult[1][1] assert 13.13 == pointResult[1][2] assert 14.14 == pointResult[1][3] def test_point3d_array(self): df = self.init_test( "CREATE (p:Person {location: [point({x: 12.12, y: 13.13, z: 1}), point({x: 14.14, y: 15.15, z: 1})]})" ) pointResult = df.select("location").collect()[0].location assert "point-3d" == pointResult[0][0] assert 9157 == pointResult[0][1] assert 12.12 == pointResult[0][2] assert 13.13 == pointResult[0][3] assert 1.0 == pointResult[0][4] assert "point-3d" == pointResult[1][0] assert 9157 == pointResult[1][1] assert 14.14 == pointResult[1][2] assert 15.15 == pointResult[1][3] assert 1.0 == pointResult[1][4] def test_geopoint_array(self): df = self.init_test( "CREATE (p:Person {location: [point({longitude: 12.12, latitude: 13.13}), point({longitude: 14.14, latitude: 15.15})]})" ) pointResult = df.select("location").collect()[0].location assert "point-2d" == pointResult[0][0] assert 4326 == pointResult[0][1] assert 12.12 == pointResult[0][2] assert 13.13 == pointResult[0][3] assert "point-2d" == pointResult[1][0] assert 4326 == pointResult[1][1] assert 14.14 == pointResult[1][2] assert 15.15 == pointResult[1][3] def test_duration_array(self): df = self.init_test( "CREATE (p:Person {range: [duration({days: 14, 
hours:16, minutes: 12}), duration({days: 15, hours:16, minutes: 12})]})" ) durationResult = df.select("range").collect()[0].range assert "duration" == durationResult[0][0] assert 0 == durationResult[0][1] assert 14 == durationResult[0][2] assert 58320 == durationResult[0][3] assert 0 == durationResult[0][4] assert "duration" == durationResult[1][0] assert 0 == durationResult[1][1] assert 15 == durationResult[1][2] assert 58320 == durationResult[1][3] assert 0 == durationResult[1][4] def test_unexisting_property(self): ( self.spark.read.format("org.neo4j.spark.DataSource") .option("url", self.neo4j_container.get_connection_url()) .option("authentication.type", "basic") .option("authentication.basic.username", "neo4j") .option("authentication.basic.password", "password") .option("relationship.properties", None) .option("relationship", "FOO") .option("relationship.source.labels", ":Foo") .option("relationship.target.labels", ":Bar") .load() ) # In this case we just test that the job has been executed without any exception def test_gds(self): with self.neo4j_driver.session() as session: session.run( """ CREATE (home:Page {name:'Home'}), (about:Page {name:'About'}), (product:Page {name:'Product'}), (links:Page {name:'Links'}), (a:Page {name:'Site A'}), (b:Page {name:'Site B'}), (c:Page {name:'Site C'}), (d:Page {name:'Site D'}), (home)-[:LINKS {weight: 0.2}]->(about), (home)-[:LINKS {weight: 0.2}]->(links), (home)-[:LINKS {weight: 0.6}]->(product), (about)-[:LINKS {weight: 1.0}]->(home), (product)-[:LINKS {weight: 1.0}]->(home), (a)-[:LINKS {weight: 1.0}]->(home), (b)-[:LINKS {weight: 1.0}]->(home), (c)-[:LINKS {weight: 1.0}]->(home), (d)-[:LINKS {weight: 1.0}]->(home), (links)-[:LINKS {weight: 0.8}]->(home), (links)-[:LINKS {weight: 0.05}]->(a), (links)-[:LINKS {weight: 0.05}]->(b), (links)-[:LINKS {weight: 0.05}]->(c), (links)-[:LINKS {weight: 0.05}]->(d); """ ) ( self.spark.read.format("org.neo4j.spark.DataSource") .option("url", 
self.neo4j_container.get_connection_url()) .option("authentication.type", "basic") .option("authentication.basic.username", "neo4j") .option("authentication.basic.password", "password") .option("gds", "gds.graph.project") .option("gds.graphName", "myGraph") .option("gds.nodeProjection", "Page") .option("gds.relationshipProjection", "LINKS") .option("gds.configuration.relationshipProperties", "weight") .load() .show(truncate=False) ) df = ( self.spark.read.format("org.neo4j.spark.DataSource") .option("url", self.neo4j_container.get_connection_url()) .option("authentication.type", "basic") .option("authentication.basic.username", "neo4j") .option("authentication.basic.password", "password") .option("gds", "gds.pageRank.stream") .option("gds.graphName", "myGraph") .option("gds.configuration.concurrency", "2") .load() ) assert 8 == df.count() if len(sys.argv) != 3: print("Wrong arguments count") print(sys.argv) sys.exit(1) neo4j_image = str(sys.argv.pop()) connector_jar = str(sys.argv.pop()) current_time_zone = get_localzone().zone if __name__ == "__main__": with ( Neo4jContainer(neo4j_image) .with_env("NEO4J_ACCEPT_LICENSE_AGREEMENT", "yes") .with_env("NEO4J_db_temporal_timezone", current_time_zone) .with_env("NEO4JLABS_PLUGINS", '["graph-data-science"]') ) as neo4j_container: with neo4j_container.get_driver() as neo4j_driver: SparkTest.spark = ( SparkSession.builder.appName("Neo4jConnectorTests") .master("local[*]") .config("spark.jars", connector_jar) .config("spark.driver.host", "127.0.0.1") .getOrCreate() ) SparkTest.neo4j_driver = neo4j_driver SparkTest.neo4j_container = neo4j_container unittest.main() SparkTest.spark.close() ================================================ FILE: scripts/release/upload_to_spark_packages.sh ================================================ #!/bin/bash set -eEuo pipefail if [[ $# -lt 5 ]] ; then echo "Usage ./upload_to_spark_packages.sh " exit 1 fi USER=$1 TOKEN=$2 GIT_HASH=$3 VERSION=$4 PATH_TO_PACKAGE_FILE=$5 # License codes 
expected: # 0 - Apache 2.0 # 1 - BSD 3-Clause # 2 - BSD 2-Clause # 3 - GPL-2.0 # 4 - GPL-3.0 # 5 - LGPL-2.1 # 6 - LGPL-3.0 # 7 - MIT # 8 - MPL-2.0 # 9 - EPL-1.0 # 10 - Other license LICENSE="0" curl -X POST 'https://spark-packages.org/api/submit-release' \ -u "$USER:$TOKEN" \ -F "git_commit_sha1=$GIT_HASH" \ -F "version=$VERSION" \ -F "license_id=$LICENSE" \ -F "name=neo4j/neo4j-spark-connector" \ -F "artifact_zip=@$PATH_TO_PACKAGE_FILE;type=application/zip" ================================================ FILE: spark-3/LICENSES.txt ================================================ This file contains the full license text of the included third party libraries. For an overview of the licenses see the NOTICE.txt file. ------------------------------------------------------------------------------ Apache Software License, Version 2.0 JetBrains Java Annotations Kotlin Stdlib Netty/Buffer Netty/Codec Netty/Common Netty/Handler Netty/Resolver Netty/TomcatNative [OpenSSL - Classes] Netty/Transport Netty/Transport/Native/Unix/Common Non-Blocking Reactive Foundation for the JVM org.apiguardian:apiguardian-api ------------------------------------------------------------------------------ Apache License Version 2.0, January 2004 http://www.apache.org/licenses/ TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 1. Definitions. "License" shall mean the terms and conditions for use, reproduction, and distribution as defined by Sections 1 through 9 of this document. "Licensor" shall mean the copyright owner or entity authorized by the copyright owner that is granting the License. "Legal Entity" shall mean the union of the acting entity and all other entities that control, are controlled by, or are under common control with that entity. 
For the purposes of this definition, "control" means (i) the power, direct or indirect, to cause the direction or management of such entity, whether by contract or otherwise, or (ii) ownership of fifty percent (50%) or more of the outstanding shares, or (iii) beneficial ownership of such entity. "You" (or "Your") shall mean an individual or Legal Entity exercising permissions granted by this License. "Source" form shall mean the preferred form for making modifications, including but not limited to software source code, documentation source, and configuration files. "Object" form shall mean any form resulting from mechanical transformation or translation of a Source form, including but not limited to compiled object code, generated documentation, and conversions to other media types. "Work" shall mean the work of authorship, whether in Source or Object form, made available under the License, as indicated by a copyright notice that is included in or attached to the work (an example is provided in the Appendix below). "Derivative Works" shall mean any work, whether in Source or Object form, that is based on (or derived from) the Work and for which the editorial revisions, annotations, elaborations, or other modifications represent, as a whole, an original work of authorship. For the purposes of this License, Derivative Works shall not include works that remain separable from, or merely link (or bind by name) to the interfaces of, the Work and Derivative Works thereof. "Contribution" shall mean any work of authorship, including the original version of the Work and any modifications or additions to that Work or Derivative Works thereof, that is intentionally submitted to Licensor for inclusion in the Work by the copyright owner or by an individual or Legal Entity authorized to submit on behalf of the copyright owner. 
For the purposes of this definition, "submitted" means any form of electronic, verbal, or written communication sent to the Licensor or its representatives, including but not limited to communication on electronic mailing lists, source code control systems, and issue tracking systems that are managed by, or on behalf of, the Licensor for the purpose of discussing and improving the Work, but excluding communication that is conspicuously marked or otherwise designated in writing by the copyright owner as "Not a Contribution." "Contributor" shall mean Licensor and any individual or Legal Entity on behalf of whom a Contribution has been received by Licensor and subsequently incorporated within the Work. 2. Grant of Copyright License. Subject to the terms and conditions of this License, each Contributor hereby grants to You a perpetual, worldwide, non-exclusive, no-charge, royalty-free, irrevocable copyright license to reproduce, prepare Derivative Works of, publicly display, publicly perform, sublicense, and distribute the Work and such Derivative Works in Source or Object form. 3. Grant of Patent License. Subject to the terms and conditions of this License, each Contributor hereby grants to You a perpetual, worldwide, non-exclusive, no-charge, royalty-free, irrevocable (except as stated in this section) patent license to make, have made, use, offer to sell, sell, import, and otherwise transfer the Work, where such license applies only to those patent claims licensable by such Contributor that are necessarily infringed by their Contribution(s) alone or by combination of their Contribution(s) with the Work to which such Contribution(s) was submitted. 
If You institute patent litigation against any entity (including a cross-claim or counterclaim in a lawsuit) alleging that the Work or a Contribution incorporated within the Work constitutes direct or contributory patent infringement, then any patent licenses granted to You under this License for that Work shall terminate as of the date such litigation is filed. 4. Redistribution. You may reproduce and distribute copies of the Work or Derivative Works thereof in any medium, with or without modifications, and in Source or Object form, provided that You meet the following conditions: (a) You must give any other recipients of the Work or Derivative Works a copy of this License; and (b) You must cause any modified files to carry prominent notices stating that You changed the files; and (c) You must retain, in the Source form of any Derivative Works that You distribute, all copyright, patent, trademark, and attribution notices from the Source form of the Work, excluding those notices that do not pertain to any part of the Derivative Works; and (d) If the Work includes a "NOTICE" text file as part of its distribution, then any Derivative Works that You distribute must include a readable copy of the attribution notices contained within such NOTICE file, excluding those notices that do not pertain to any part of the Derivative Works, in at least one of the following places: within a NOTICE text file distributed as part of the Derivative Works; within the Source form or documentation, if provided along with the Derivative Works; or, within a display generated by the Derivative Works, if and wherever such third-party notices normally appear. The contents of the NOTICE file are for informational purposes only and do not modify the License. 
You may add Your own attribution notices within Derivative Works that You distribute, alongside or as an addendum to the NOTICE text from the Work, provided that such additional attribution notices cannot be construed as modifying the License. You may add Your own copyright statement to Your modifications and may provide additional or different license terms and conditions for use, reproduction, or distribution of Your modifications, or for any such Derivative Works as a whole, provided Your use, reproduction, and distribution of the Work otherwise complies with the conditions stated in this License. 5. Submission of Contributions. Unless You explicitly state otherwise, any Contribution intentionally submitted for inclusion in the Work by You to the Licensor shall be under the terms and conditions of this License, without any additional terms or conditions. Notwithstanding the above, nothing herein shall supersede or modify the terms of any separate license agreement you may have executed with Licensor regarding such Contributions. 6. Trademarks. This License does not grant permission to use the trade names, trademarks, service marks, or product names of the Licensor, except as required for reasonable and customary use in describing the origin of the Work and reproducing the content of the NOTICE file. 7. Disclaimer of Warranty. Unless required by applicable law or agreed to in writing, Licensor provides the Work (and each Contributor provides its Contributions) on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied, including, without limitation, any warranties or conditions of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A PARTICULAR PURPOSE. You are solely responsible for determining the appropriateness of using or redistributing the Work and assume any risks associated with Your exercise of permissions under this License. 8. Limitation of Liability. 
In no event and under no legal theory, whether in tort (including negligence), contract, or otherwise, unless required by applicable law (such as deliberate and grossly negligent acts) or agreed to in writing, shall any Contributor be liable to You for damages, including any direct, indirect, special, incidental, or consequential damages of any character arising as a result of this License or out of the use or inability to use the Work (including but not limited to damages for loss of goodwill, work stoppage, computer failure or malfunction, or any and all other commercial damages or losses), even if such Contributor has been advised of the possibility of such damages. 9. Accepting Warranty or Additional Liability. While redistributing the Work or Derivative Works thereof, You may choose to offer, and charge a fee for, acceptance of support, warranty, indemnity, or other liability obligations and/or rights consistent with this License. However, in accepting such obligations, You may act only on Your own behalf and on Your sole responsibility, not on behalf of any other Contributor, and only if You agree to indemnify, defend, and hold each Contributor harmless for any liability incurred by, or claims asserted against, such Contributor by reason of your accepting any such warranty or additional liability. END OF TERMS AND CONDITIONS APPENDIX: How to apply the Apache License to your work. To apply the Apache License to your work, attach the following boilerplate notice, with the fields enclosed by brackets "[]" replaced with your own identifying information. (Don't include the brackets!) The text should be enclosed in the appropriate comment syntax for the file format. We also recommend that a file or class name and description of purpose be included on the same "printed page" as the copyright notice for easier identification within third-party archives. 
Copyright [yyyy] [name of copyright owner] Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ------------------------------------------------------------------------------ MIT License SLF4J API Module ------------------------------------------------------------------------------ The MIT License Copyright (c) Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions: The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software. THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. 
------------------------------------------------------------------------------ MIT No Attribution License reactive-streams ------------------------------------------------------------------------------ MIT No Attribution Copyright Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so. THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. ================================================ FILE: spark-3/NOTICE.txt ================================================ Copyright (c) "Neo4j" Neo4j Sweden AB [https://neo4j.com] This file is part of Neo4j. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. Full license texts are found in LICENSES.txt. 
Third-party licenses -------------------- Apache Software License, Version 2.0 JetBrains Java Annotations Kotlin Stdlib Netty/Buffer Netty/Codec Netty/Common Netty/Handler Netty/Resolver Netty/TomcatNative [OpenSSL - Classes] Netty/Transport Netty/Transport/Native/Unix/Common Non-Blocking Reactive Foundation for the JVM org.apiguardian:apiguardian-api MIT License SLF4J API Module MIT No Attribution License reactive-streams ================================================ FILE: spark-3/pom.xml ================================================ 4.0.0 org.neo4j neo4j-connector-apache-spark_parent 5.4.3-SNAPSHOT neo4j-connector-apache-spark jar neo4j-connector-apache-spark-${spark.version} Spark ${spark.version} for Neo4j Connector for Apache Spark using the binary Bolt Driver neo4j-spark-connector neo4j ${spark-packages.artifactId}-${spark-packages.version} org.neo4j neo4j-connector-apache-spark_common ${project.version} org.neo4j.driver neo4j-java-driver-slim org.apache.spark spark-core_${scala.binary.version} provided org.apache.spark spark-sql_${scala.binary.version} provided org.scala-lang scala-library provided org.scala-lang scala-reflect provided org.neo4j neo4j-connector-apache-spark_test-support ${project.version} test org.neo4j.connectors commons-authn-keycloak test org.scalatest scalatest_${scala.binary.version} test pl.pragmatists JUnitParams 1.1.1 test true src/main/resources net.alchim31.maven scala-maven-plugin org.apache.maven.plugins maven-failsafe-plugin org.apache.maven.plugins maven-surefire-plugin org.apache.maven.plugins maven-resources-plugin ${maven-resources-plugin.version} ISO-8859-1 ${project.build.sourceEncoding} maven-assembly-plugin 3.8.0 bin single package src/main/assemblies/spark-packages-assembly.xml false ${spark-packages.packageName} ================================================ FILE: spark-3/src/jreleaser/assemblers/zip/README.txt.tpl ================================================ Neo4j Connector for Apache Spark 
{{projectVersion}} This archive contains release materials for the Neo4j Connector for Apache Spark. **Make sure you use the correct JAR, depending on the Scala version in use in your spark environment** Source Code & Release Notes: https://github.com/neo4j/neo4j-spark-connector/releases Documentation: https://neo4j.com/docs/spark/current/ Need Support? Have a Question? If you are a Neo4j Enterprise customer, please have a look at the enterprise knowledge base, or contact Neo4j Professional Support through: https://support.neo4j.com/hc/en-us Otherwise, please consider posting a question on the Neo4j Community site: https://community.neo4j.com/ ================================================ FILE: spark-3/src/main/assemblies/spark-packages-assembly.xml ================================================ bin zip false ${project.build.directory}/${project.artifactId}-${project.version}.jar / ${spark-packages.packageName}.jar true src/main/distributions/spark-packages.pom / ${spark-packages.packageName}.pom ================================================ FILE: spark-3/src/main/distributions/spark-packages.pom ================================================ 4.0.0 ${spark-packages.groupId} ${spark-packages.artifactId} ${spark-packages.version} ================================================ FILE: spark-3/src/main/resources/META-INF/services/org.apache.spark.sql.sources.DataSourceRegister ================================================ org.neo4j.spark.DataSource ================================================ FILE: spark-3/src/main/resources/neo4j-spark-connector.properties ================================================ version=${project.version} ================================================ FILE: spark-3/src/main/scala/org/neo4j/spark/DataSource.scala ================================================ /* * Copyright (c) "Neo4j" * Neo4j Sweden AB [https://neo4j.com] * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file 
except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package org.neo4j.spark

import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.connector.catalog.Table
import org.apache.spark.sql.connector.catalog.TableProvider
import org.apache.spark.sql.connector.expressions.Transform
import org.apache.spark.sql.sources.DataSourceRegister
import org.apache.spark.sql.sources.Filter
import org.apache.spark.sql.types.StructType
import org.apache.spark.sql.util.CaseInsensitiveStringMap
import org.neo4j.caniuse.Neo4j
import org.neo4j.caniuse.Neo4jDetector
import org.neo4j.spark.util.DriverCache
import org.neo4j.spark.util.Neo4jDriverOptions
import org.neo4j.spark.util.Neo4jOptions
import org.neo4j.spark.util.Neo4jUtil
import org.neo4j.spark.util.ValidateConnection
import org.neo4j.spark.util.ValidateSparkMinVersion
import org.neo4j.spark.util.Validations

import java.util.UUID

/**
 * Spark DataSource V2 entry point for the Neo4j connector, registered under
 * the short name `neo4j`.
 *
 * Options, the detected Neo4j server information and the inferred schema are
 * resolved lazily on first use and cached for the lifetime of this instance.
 */
class DataSource extends TableProvider with DataSourceRegister {

  // Fail fast on Spark runtimes older than the minimum supported version.
  Validations.validate(ValidateSparkMinVersion("3.3.0"))

  // Identifier scoping driver/session caches to this data source instance.
  private val jobId: String = UUID.randomUUID().toString

  // Lazily-resolved caches; `null` means "not resolved yet".
  private var schema: StructType = _
  private var neo4jOptions: Neo4jOptions = _
  private var neo4j: Neo4j = _

  override def supportsExternalMetadata(): Boolean = true

  /**
   * Infers (once) the schema of the configured source by interrogating the
   * target database via the schema service.
   */
  override def inferSchema(caseInsensitiveStringMap: CaseInsensitiveStringMap): StructType = {
    if (schema == null) {
      val resolvedOptions = getNeo4jOptions(caseInsensitiveStringMap)
      Validations.validate(ValidateConnection(resolvedOptions, jobId))
      val serverInfo = getNeo4jInfo(resolvedOptions.connection)
      schema = Neo4jUtil.callSchemaService(
        serverInfo,
        resolvedOptions,
        jobId,
        Array.empty[Filter],
        { schemaService => schemaService.struct() }
      )
    }
    schema
  }

  /** Detects the target Neo4j server capabilities once and caches the result. */
  private def getNeo4jInfo(options: Neo4jDriverOptions): Neo4j = {
    if (neo4j == null) {
      neo4j = Neo4jDetector.INSTANCE.detect(new DriverCache(options).getOrCreate())
    }
    neo4j
  }

  /** Builds (once) the connector options from the active session config plus the user-supplied map. */
  private def getNeo4jOptions(caseInsensitiveStringMap: CaseInsensitiveStringMap) = {
    if (neo4jOptions == null) {
      neo4jOptions =
        Neo4jOptions.fromSession(SparkSession.getActiveSession, caseInsensitiveStringMap.asCaseSensitiveMap())
    }
    neo4jOptions
  }

  override def getTable(
    structType: StructType,
    transforms: Array[Transform],
    map: java.util.Map[String, String]
  ): Table = {
    val optionsMap = new CaseInsensitiveStringMap(map)
    // Honour an externally supplied schema, falling back to inference.
    val tableSchema = if (structType == null) inferSchema(optionsMap) else structType
    val resolvedOptions = getNeo4jOptions(optionsMap)
    new Neo4jTable(getNeo4jInfo(resolvedOptions.connection), tableSchema, resolvedOptions, jobId)
  }

  override def shortName(): String = "neo4j"
}

================================================
FILE: spark-3/src/main/scala/org/neo4j/spark/Neo4jTable.scala
================================================
/*
 * Copyright (c) "Neo4j"
 * Neo4j Sweden AB [https://neo4j.com]
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
*/ package org.neo4j.spark import org.apache.spark.internal.Logging import org.apache.spark.sql.SaveMode import org.apache.spark.sql.connector.catalog.SupportsRead import org.apache.spark.sql.connector.catalog.SupportsWrite import org.apache.spark.sql.connector.catalog.Table import org.apache.spark.sql.connector.catalog.TableCapability import org.apache.spark.sql.connector.write.LogicalWriteInfo import org.apache.spark.sql.connector.write.WriteBuilder import org.apache.spark.sql.types.StructType import org.apache.spark.sql.util.CaseInsensitiveStringMap import org.neo4j.caniuse.Neo4j import org.neo4j.driver.AccessMode import org.neo4j.spark.reader.Neo4jScanBuilder import org.neo4j.spark.util.Neo4jOptions import org.neo4j.spark.util.ValidateRead import org.neo4j.spark.util.Validations import org.neo4j.spark.writer.Neo4jWriterBuilder import scala.collection.JavaConverters._ class Neo4jTable(neo4j: Neo4j, schema: StructType, neo4jOptions: Neo4jOptions, jobId: String) extends Table with SupportsRead with SupportsWrite with Logging { override def name(): String = neo4jOptions.getTableName override def schema(): StructType = schema override def capabilities(): java.util.Set[TableCapability] = Set( TableCapability.BATCH_READ, TableCapability.BATCH_WRITE, TableCapability.ACCEPT_ANY_SCHEMA, TableCapability.OVERWRITE_BY_FILTER, TableCapability.OVERWRITE_DYNAMIC, TableCapability.STREAMING_WRITE, TableCapability.MICRO_BATCH_READ ).asJava override def newScanBuilder(options: CaseInsensitiveStringMap): Neo4jScanBuilder = { Validations.validate(ValidateRead(neo4j, neo4jOptions, jobId)) new Neo4jScanBuilder(neo4j, neo4jOptions, jobId, schema()) } override def newWriteBuilder(info: LogicalWriteInfo): WriteBuilder = { val mapOptions = neo4jOptions.asMap() mapOptions.put(Neo4jOptions.ACCESS_MODE, AccessMode.WRITE.toString) val writeNeo4jOptions = new Neo4jOptions(mapOptions) new Neo4jWriterBuilder(neo4j, info.queryId(), info.schema(), SaveMode.Append, writeNeo4jOptions) } } 
================================================
FILE: spark-3/src/main/scala/org/neo4j/spark/reader/Neo4jPartitionReader.scala
================================================
/*
 * Copyright (c) "Neo4j"
 * Neo4j Sweden AB [https://neo4j.com]
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package org.neo4j.spark.reader

import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.connector.expressions.aggregate.AggregateFunc
import org.apache.spark.sql.connector.read.PartitionReader
import org.apache.spark.sql.sources.Filter
import org.apache.spark.sql.types.StructType
import org.neo4j.caniuse.Neo4j
import org.neo4j.spark.service.PartitionPagination
import org.neo4j.spark.util.Neo4jOptions

/**
 * Spark 3 batch [[PartitionReader]] for a single Neo4j read partition.
 *
 * Pure adapter: all reading logic lives in the shared `BasePartitionReader`
 * (common module); this class only forwards the constructor arguments and
 * mixes in the Spark 3 `PartitionReader[InternalRow]` interface.
 *
 * @param partitionSkipLimit pagination window (skip/limit) assigned to this partition
 * @param scriptResult       result of the user-configured init script, shared across partitions
 * @param requiredColumns    pruned column set pushed down by Spark (may be empty)
 * @param aggregateColumns   aggregate functions pushed down by Spark (may be empty)
 */
class Neo4jPartitionReader(
  private val neo4j: Neo4j,
  private val options: Neo4jOptions,
  private val filters: Array[Filter],
  private val schema: StructType,
  private val jobId: String,
  private val partitionSkipLimit: PartitionPagination,
  private val scriptResult: java.util.List[java.util.Map[String, AnyRef]],
  private val requiredColumns: StructType,
  private val aggregateColumns: Array[AggregateFunc]
) extends BasePartitionReader(
      neo4j,
      options,
      filters,
      schema,
      jobId,
      partitionSkipLimit,
      scriptResult,
      requiredColumns,
      aggregateColumns
    )
    with PartitionReader[InternalRow]

================================================
FILE: spark-3/src/main/scala/org/neo4j/spark/reader/Neo4jPartitionReaderFactory.scala
================================================
/*
 * Copyright (c) "Neo4j"
* Neo4j Sweden AB [https://neo4j.com] * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ package org.neo4j.spark.reader import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.connector.expressions.aggregate.AggregateFunc import org.apache.spark.sql.connector.read.InputPartition import org.apache.spark.sql.connector.read.PartitionReader import org.apache.spark.sql.connector.read.PartitionReaderFactory import org.apache.spark.sql.sources.Filter import org.apache.spark.sql.types.StructType import org.neo4j.caniuse.Neo4j import org.neo4j.spark.service.PartitionPagination import org.neo4j.spark.util.Neo4jOptions class Neo4jPartitionReaderFactory( private val neo4j: Neo4j, private val neo4jOptions: Neo4jOptions, private val filters: Array[Filter], private val schema: StructType, private val jobId: String, private val scriptResult: java.util.List[java.util.Map[String, AnyRef]], private val requiredColumns: StructType, private val aggregateColumns: Array[AggregateFunc] ) extends PartitionReaderFactory { override def createReader(partition: InputPartition): PartitionReader[InternalRow] = new Neo4jPartitionReader( neo4j, neo4jOptions, filters, schema, jobId, partition.asInstanceOf[Neo4jPartition].partitionSkipLimit, scriptResult, requiredColumns, aggregateColumns ) } ================================================ FILE: spark-3/src/main/scala/org/neo4j/spark/reader/Neo4jScan.scala ================================================ /* * Copyright (c) "Neo4j" * Neo4j 
Sweden AB [https://neo4j.com]
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package org.neo4j.spark.reader

import org.apache.spark.sql.connector.expressions.aggregate.AggregateFunc
import org.apache.spark.sql.connector.read.Batch
import org.apache.spark.sql.connector.read.InputPartition
import org.apache.spark.sql.connector.read.PartitionReaderFactory
import org.apache.spark.sql.connector.read.Scan
import org.apache.spark.sql.connector.read.streaming.MicroBatchStream
import org.apache.spark.sql.sources.Filter
import org.apache.spark.sql.types.StructType
import org.neo4j.caniuse.Neo4j
import org.neo4j.spark.config.TopN
import org.neo4j.spark.service.PartitionPagination
import org.neo4j.spark.streaming.Neo4jMicroBatchReader
import org.neo4j.spark.util.Neo4jOptions
import org.neo4j.spark.util.Neo4jUtil
import org.neo4j.spark.util.ValidateReadNotStreaming
import org.neo4j.spark.util.ValidateReadStreaming
import org.neo4j.spark.util.Validations

/** One Neo4j read partition, identified by its pagination (skip/limit) window. */
case class Neo4jPartition(partitionSkipLimit: PartitionPagination) extends InputPartition

/**
 * V2 [[Scan]] over Neo4j that doubles as its own [[Batch]] and can be turned
 * into a [[MicroBatchStream]] for structured streaming reads.
 *
 * Partition planning also runs the user-configured init script once; its
 * result is kept and handed to every partition reader.
 */
class Neo4jScan(
  neo4j: Neo4j,
  neo4jOptions: Neo4jOptions,
  jobId: String,
  schema: StructType,
  filters: Array[Filter],
  requiredColumns: StructType,
  aggregateColumns: Array[AggregateFunc],
  topN: Option[TopN]
) extends Scan
    with Batch {

  override def toBatch: Batch = this

  // Populated during partition planning, consumed by createReaderFactory.
  var scriptResult: java.util.List[java.util.Map[String, AnyRef]] = _

  /** Computes the pagination window per partition and runs the init script. */
  private def createPartitions() = {
    Validations.validate(ValidateReadNotStreaming(neo4jOptions, jobId))
    // One schema-service round trip yields both the partition windows and
    // the init-script output.
    val (paginationWindows, scriptOutput) = Neo4jUtil.callSchemaService(
      neo4j,
      neo4jOptions,
      jobId,
      filters,
      { schemaService =>
        (schemaService.skipLimitFromPartition(topN), schemaService.execute(neo4jOptions.script))
      }
    )
    this.scriptResult = scriptOutput
    paginationWindows.map(Neo4jPartition(_))
  }

  override def planInputPartitions(): Array[InputPartition] = {
    val planned: Seq[Neo4jPartition] = createPartitions()
    planned.toArray
  }

  override def createReaderFactory(): PartitionReaderFactory =
    new Neo4jPartitionReaderFactory(
      neo4j,
      neo4jOptions,
      filters,
      schema,
      jobId,
      scriptResult,
      requiredColumns,
      aggregateColumns
    )

  override def readSchema(): StructType = schema

  override def toMicroBatchStream(checkpointLocation: String): MicroBatchStream = {
    Validations.validate(ValidateReadStreaming(neo4j, neo4jOptions, jobId))
    new Neo4jMicroBatchReader(neo4j, schema, neo4jOptions, jobId, aggregateColumns)
  }
}

================================================
FILE: spark-3/src/main/scala/org/neo4j/spark/reader/Neo4jScanBuilder.scala
================================================
/*
 * Copyright (c) "Neo4j"
 * Neo4j Sweden AB [https://neo4j.com]
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
*/ package org.neo4j.spark.reader import org.apache.spark.internal.Logging import org.apache.spark.sql.connector.expressions.Expression import org.apache.spark.sql.connector.expressions.SortOrder import org.apache.spark.sql.connector.expressions.aggregate.AggregateFunc import org.apache.spark.sql.connector.expressions.aggregate.Aggregation import org.apache.spark.sql.connector.expressions.filter.Predicate import org.apache.spark.sql.connector.read.Scan import org.apache.spark.sql.connector.read.SupportsPushDownAggregates import org.apache.spark.sql.connector.read.SupportsPushDownLimit import org.apache.spark.sql.connector.read.SupportsPushDownRequiredColumns import org.apache.spark.sql.connector.read.SupportsPushDownTopN import org.apache.spark.sql.connector.read.SupportsPushDownV2Filters import org.apache.spark.sql.types.LongType import org.apache.spark.sql.types.StructField import org.apache.spark.sql.types.StructType import org.neo4j.caniuse.Neo4j import org.neo4j.spark.config.TopN import org.neo4j.spark.util.Neo4jImplicits.AggregationImplicit import org.neo4j.spark.util.Neo4jImplicits.CypherImplicits import org.neo4j.spark.util.Neo4jImplicits.PredicateImplicit import org.neo4j.spark.util.Neo4jOptions import org.neo4j.spark.util.QueryType class Neo4jScanBuilder(neo4jInfo: Neo4j, neo4jOptions: Neo4jOptions, jobId: String, schema: StructType) extends SupportsPushDownV2Filters with SupportsPushDownAggregates with SupportsPushDownRequiredColumns with SupportsPushDownLimit with SupportsPushDownTopN with Logging { private var predicates: Array[Predicate] = Array.empty private var requiredSchema: StructType = schema private var requiredColumns: StructType = new StructType() private var aggregateColumns: Array[AggregateFunc] = Array.empty[AggregateFunc] private var limit: Option[Int] = None private var topN: Option[TopN] = None override def build(): Scan = { new Neo4jScan( neo4jInfo, neo4jOptions, jobId, requiredSchema, predicates.flatMap(_.toFilter(neo4jOptions)), 
requiredColumns, aggregateColumns, topN.orElse(limit.map((limit: Int) => TopN(limit))) ) } override def pushPredicates(predicatesArray: Array[Predicate]): Array[Predicate] = { if (neo4jOptions.pushdownFiltersEnabled) { predicates = predicatesArray } predicatesArray } override def pushedPredicates(): Array[Predicate] = predicates override def pruneColumns(newSchema: StructType): Unit = { if (!neo4jOptions.pushdownColumnsEnabled || neo4jOptions.relationshipMetadata.nodeMap) { new StructType() } else { requiredColumns = StructType(requiredSchema.filter(sf => newSchema.contains(sf))) } } override def pushAggregation(aggregation: Aggregation): Boolean = { if ( !neo4jOptions.pushdownAggregateEnabled || aggregation.aggregateExpressions().isEmpty || neo4jOptions.query.queryType == QueryType.QUERY ) { return false } aggregateColumns = aggregation.aggregateExpressions() val groupByColumns: Set[String] = aggregation.groupByCols() .map(_.describe().unquote()) .toSet requiredColumns = StructType(requiredSchema.filter(field => groupByColumns.contains(field.name))) aggregateColumns.foreach(af => { val fields = try { af.children() .toSet[Expression] .map(_.describe()) .map(_.unquote()) } catch { // for making it compatible with Spark 3.2 case noSuchMethodException: NoSuchMethodError => Set(af.describe().unquote()) } val dt = if (fields.nonEmpty) { requiredSchema.filter(field => fields.contains(field.name)) .map(_.dataType) .toSet .headOption .getOrElse(LongType) } else { LongType } requiredColumns = requiredColumns.add(StructField(af.toString, dt)) }) requiredSchema = requiredColumns true } override def pushLimit(pushedLimit: Int): Boolean = { if (!neo4jOptions.pushdownLimitEnabled) { return false } if (neo4jOptions.partitions > 1) { logWarning( s"""Disabling pushed down limit support since it conflicts with partitioning. |Set the `${Neo4jOptions.PARTITIONS}` parameter value to 1 | or set `${Neo4jOptions.PUSHDOWN_LIMIT_ENABLED}` to false to remove this warning. 
|""".stripMargin ) return false } if (pushedLimit <= 0) { logWarning(s"Ignoring negative pushed down limit $pushedLimit.") return false } limit = Some(pushedLimit) true } override def pushTopN(orders: Array[SortOrder], limit: Int): Boolean = { if (!neo4jOptions.pushdownTopNEnabled) { return false } if (limit > 0 && neo4jOptions.partitions > 1) { logWarning("disabling pushed down top N support since it conflicts with partitioning." + "set the partition count to 1 or disable the pushdown limit support to remove this warning") return false } topN = Some(TopN(limit, orders)) true } // otherwise doesn't compile in Scala 2.12 override def isPartiallyPushed: Boolean = true } ================================================ FILE: spark-3/src/main/scala/org/neo4j/spark/streaming/Neo4jMicroBatchReader.scala ================================================ /* * Copyright (c) "Neo4j" * Neo4j Sweden AB [https://neo4j.com] * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. 
*/ package org.neo4j.spark.streaming import org.apache.spark.internal.Logging import org.apache.spark.sql.connector.expressions.aggregate.AggregateFunc import org.apache.spark.sql.connector.read.InputPartition import org.apache.spark.sql.connector.read.PartitionReaderFactory import org.apache.spark.sql.connector.read.streaming.MicroBatchStream import org.apache.spark.sql.connector.read.streaming.Offset import org.apache.spark.sql.sources.Filter import org.apache.spark.sql.sources.GreaterThan import org.apache.spark.sql.sources.LessThanOrEqual import org.apache.spark.sql.types.StructType import org.neo4j.caniuse.Neo4j import org.neo4j.spark.service.SchemaService import org.neo4j.spark.util._ class Neo4jMicroBatchReader( private val neo4j: Neo4j, private val schema: StructType, private val neo4jOptions: Neo4jOptions, private val jobId: String, private val aggregateColumns: Array[AggregateFunc] ) extends MicroBatchStream with Logging { private val driverCache = new DriverCache(neo4jOptions.connection) private lazy val scriptResult = { val schemaService = new SchemaService(neo4j, neo4jOptions, driverCache) schemaService.createOptimizations(schema) val scriptResult = schemaService.execute(neo4jOptions.script) schemaService.close() scriptResult } private var filters: Array[Filter] = Array.empty[Filter] override def deserializeOffset(json: String): Offset = Neo4jOffset(json.toLong) override def commit(end: Offset): Unit = {} override def planInputPartitions(start: Offset, end: Offset): Array[InputPartition] = { logDebug(s"start and end offset: $start - $end") val prop = Neo4jUtil.getStreamingPropertyName(neo4jOptions) this.filters = Array( GreaterThan(prop, start.asInstanceOf[Neo4jOffset].offset), LessThanOrEqual(prop, end.asInstanceOf[Neo4jOffset].offset) ) val partitions = Neo4jUtil.callSchemaService( neo4j, neo4jOptions, jobId, filters, { schemaService => schemaService.skipLimitFromPartition(None) } ) partitions .map(p => Neo4jStreamingPartition(p, filters)) .toArray } 
override def stop(): Unit = { driverCache.close() } override def latestOffset(): Offset = { val offsetValue = Neo4jUtil.callSchemaService[Option[Long]]( neo4j, neo4jOptions, jobId, filters, { schemaService => try { schemaService.lastOffset() } catch { case _: Throwable => null } } ) offsetValue.map(value => Neo4jOffset(value)).orNull } override def initialOffset(): Offset = Neo4jOffset(neo4jOptions.streamingOptions.from.value()) override def createReaderFactory(): PartitionReaderFactory = { new Neo4jStreamingPartitionReaderFactory( neo4j, neo4jOptions, schema, jobId, scriptResult, aggregateColumns ) } } ================================================ FILE: spark-3/src/main/scala/org/neo4j/spark/streaming/Neo4jOffset.scala ================================================ /* * Copyright (c) "Neo4j" * Neo4j Sweden AB [https://neo4j.com] * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ package org.neo4j.spark.streaming import org.apache.spark.sql.connector.read.streaming.Offset case class Neo4jOffset(offset: Long) extends Offset { override def json(): String = offset.toString } ================================================ FILE: spark-3/src/main/scala/org/neo4j/spark/streaming/Neo4jStreamingDataWriterFactory.scala ================================================ /* * Copyright (c) "Neo4j" * Neo4j Sweden AB [https://neo4j.com] * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. 
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package org.neo4j.spark.streaming

import org.apache.spark.sql.SaveMode
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.connector.write.DataWriter
import org.apache.spark.sql.connector.write.streaming.StreamingDataWriterFactory
import org.apache.spark.sql.types.StructType
import org.neo4j.caniuse.Neo4j
import org.neo4j.spark.util.Neo4jOptions
import org.neo4j.spark.writer.Neo4jDataWriter

/**
 * Streaming writer factory: creates one [[Neo4jDataWriter]] per partition/task.
 *
 * Note that the same batch writer implementation is reused for streaming
 * epochs; the `epochId` is accepted but not forwarded to the writer.
 *
 * @param scriptResult result of the user-configured init script, computed once
 *                     by the streaming write and shared with every writer
 */
class Neo4jStreamingDataWriterFactory(
  neo4j: Neo4j,
  jobId: String,
  schema: StructType,
  saveMode: SaveMode,
  options: Neo4jOptions,
  scriptResult: java.util.List[java.util.Map[String, AnyRef]]
) extends StreamingDataWriterFactory {

  override def createWriter(partitionId: Int, taskId: Long, epochId: Long): DataWriter[InternalRow] =
    new Neo4jDataWriter(
      neo4j,
      jobId,
      partitionId,
      schema,
      saveMode,
      options,
      scriptResult
    )
}

================================================
FILE: spark-3/src/main/scala/org/neo4j/spark/streaming/Neo4jStreamingPartitionReader.scala
================================================
/*
 * Copyright (c) "Neo4j"
 * Neo4j Sweden AB [https://neo4j.com]
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package org.neo4j.spark.streaming

import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.connector.expressions.aggregate.AggregateFunc
import org.apache.spark.sql.connector.read.PartitionReader
import org.apache.spark.sql.sources.Filter
import org.apache.spark.sql.types.StructType
// NOTE(review): the two accumulator imports below appear unused in this file —
// candidates for removal in a dedicated cleanup.
import org.apache.spark.util.AccumulatorV2
import org.apache.spark.util.LongAccumulator
import org.neo4j.caniuse.Neo4j
import org.neo4j.spark.service.PartitionPagination
import org.neo4j.spark.util.Neo4jOptions

/**
 * Spark 3 streaming [[PartitionReader]] for a single micro-batch partition.
 *
 * Pure adapter over the shared `BaseStreamingPartitionReader` (common module);
 * all reading logic lives there, this class only forwards constructor
 * arguments and mixes in the Spark 3 interface.
 *
 * @param filters            window filters (streaming-property bounds) of this micro-batch
 * @param partitionSkipLimit pagination window assigned to this partition
 * @param scriptResult       result of the user-configured init script, shared across partitions
 */
class Neo4jStreamingPartitionReader(
  private val neo4j: Neo4j,
  private val options: Neo4jOptions,
  private val filters: Array[Filter],
  private val schema: StructType,
  private val jobId: String,
  private val partitionSkipLimit: PartitionPagination,
  private val scriptResult: java.util.List[java.util.Map[String, AnyRef]],
  private val requiredColumns: StructType,
  private val aggregateColumns: Array[AggregateFunc]
) extends BaseStreamingPartitionReader(
      neo4j,
      options,
      filters,
      schema,
      jobId,
      partitionSkipLimit,
      scriptResult,
      requiredColumns,
      aggregateColumns
    )
    with PartitionReader[InternalRow] {}

================================================
FILE: spark-3/src/main/scala/org/neo4j/spark/streaming/Neo4jStreamingPartitionReaderFactory.scala
================================================
/*
 * Copyright (c) "Neo4j"
 * Neo4j Sweden AB [https://neo4j.com]
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and * limitations under the License. */ package org.neo4j.spark.streaming import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.connector.expressions.aggregate.AggregateFunc import org.apache.spark.sql.connector.read.InputPartition import org.apache.spark.sql.connector.read.PartitionReader import org.apache.spark.sql.connector.read.PartitionReaderFactory import org.apache.spark.sql.sources.Filter import org.apache.spark.sql.types.StructType import org.apache.spark.util.AccumulatorV2 import org.apache.spark.util.LongAccumulator import org.neo4j.caniuse.Neo4j import org.neo4j.spark.service.PartitionPagination import org.neo4j.spark.util.Neo4jOptions case class Neo4jStreamingPartition(partitionSkipLimit: PartitionPagination, filters: Array[Filter]) extends InputPartition class Neo4jStreamingPartitionReaderFactory( private val neo4j: Neo4j, private val neo4jOptions: Neo4jOptions, private val schema: StructType, private val jobId: String, private val scriptResult: java.util.List[java.util.Map[String, AnyRef]], private val aggregateColumns: Array[AggregateFunc] ) extends PartitionReaderFactory { override def createReader(partition: InputPartition): PartitionReader[InternalRow] = new Neo4jStreamingPartitionReader( neo4j, neo4jOptions, partition.asInstanceOf[Neo4jStreamingPartition].filters, schema, jobId, partition.asInstanceOf[Neo4jStreamingPartition].partitionSkipLimit, scriptResult, new StructType(), aggregateColumns ) } ================================================ FILE: spark-3/src/main/scala/org/neo4j/spark/streaming/Neo4jStreamingWriter.scala ================================================ /* * Copyright (c) "Neo4j" * Neo4j Sweden AB [https://neo4j.com] * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. 
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package org.neo4j.spark.streaming

import org.apache.spark.sql.SaveMode
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.connector.write.PhysicalWriteInfo
import org.apache.spark.sql.connector.write.WriterCommitMessage
import org.apache.spark.sql.connector.write.streaming.StreamingDataWriterFactory
import org.apache.spark.sql.connector.write.streaming.StreamingWrite
import org.apache.spark.sql.streaming.StreamingQueryListener
import org.apache.spark.sql.types.StructType
import org.neo4j.caniuse.Neo4j
import org.neo4j.spark.service.SchemaService
import org.neo4j.spark.util.DriverCache
import org.neo4j.spark.util.Neo4jOptions
import org.neo4j.spark.util.Neo4jUtil

import java.util.Optional

/**
 * [[StreamingWrite]] for Neo4j sinks.
 *
 * Lifecycle: constructing this object registers a [[StreamingQueryListener]]
 * on the default session so the cached driver is released when the streaming
 * query whose id matches `queryId` terminates. The listener removes itself
 * after firing.
 *
 * NOTE(review): construction calls `SparkSession.getDefaultSession.get` —
 * this assumes a default session always exists here; `.get` on `None` would
 * throw. Presumably guaranteed by the streaming execution path — confirm.
 */
class Neo4jStreamingWriter(
  val neo4j: Neo4j,
  val queryId: String,
  val schema: StructType,
  saveMode: SaveMode,
  val neo4jOptions: Neo4jOptions
) extends StreamingWrite {

  // Captured so the anonymous listener below can call close() on the outer instance.
  private val self = this

  private val listener = new StreamingQueryListener {
    override def onQueryStarted(event: StreamingQueryListener.QueryStartedEvent): Unit = ()

    override def onQueryProgress(event: StreamingQueryListener.QueryProgressEvent): Unit = ()

    override def onQueryTerminated(event: StreamingQueryListener.QueryTerminatedEvent): Unit = {
      // Only react to the termination of *this* writer's query.
      if (event.id.toString == queryId) {
        self.close()
        SparkSession.getDefaultSession.get.streams.removeListener(this)
      }
    }
  }

  // Side effect at construction time: hook the cleanup listener in immediately.
  SparkSession.getDefaultSession.get.streams.addListener(listener)

  private val driverCache = new DriverCache(neo4jOptions.connection)

  // Runs schema optimizations and the user-configured init script exactly once,
  // on first factory creation; the result is shared with every data writer.
  private lazy val scriptResult = {
    val schemaService = new SchemaService(neo4j, neo4jOptions, driverCache)
    schemaService.createOptimizations(schema)
    val scriptResult = schemaService.execute(neo4jOptions.script)
    schemaService.close()
    scriptResult
  }

  override def createStreamingWriterFactory(info: PhysicalWriteInfo): StreamingDataWriterFactory = {
    new Neo4jStreamingDataWriterFactory(
      neo4j,
      queryId,
      schema,
      saveMode,
      neo4jOptions,
      scriptResult
    )
  }

  // Writes are applied eagerly per epoch; nothing to do on commit/abort.
  override def commit(epochId: Long, messages: Array[WriterCommitMessage]): Unit = {}

  override def abort(epochId: Long, messages: Array[WriterCommitMessage]): Unit = {}

  /** Releases the cached driver; invoked by the termination listener above. */
  def close(): Unit = Neo4jUtil.closeSafely(driverCache)
}

================================================
FILE: spark-3/src/main/scala/org/neo4j/spark/writer/Neo4jBatchWriter.scala
================================================
/*
 * Copyright (c) "Neo4j"
 * Neo4j Sweden AB [https://neo4j.com]
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
*/ package org.neo4j.spark.writer import org.apache.spark.sql.SaveMode import org.apache.spark.sql.connector.write.BatchWrite import org.apache.spark.sql.connector.write.DataWriterFactory import org.apache.spark.sql.connector.write.PhysicalWriteInfo import org.apache.spark.sql.connector.write.WriterCommitMessage import org.apache.spark.sql.types.StructType import org.neo4j.caniuse.Neo4j import org.neo4j.spark.service.SchemaService import org.neo4j.spark.util.DriverCache import org.neo4j.spark.util.Neo4jOptions import java.util.Optional class Neo4jBatchWriter( neo4j: Neo4j, jobId: String, structType: StructType, saveMode: SaveMode, neo4jOptions: Neo4jOptions ) extends BatchWrite { override def createBatchWriterFactory(physicalWriteInfo: PhysicalWriteInfo): DataWriterFactory = { val schemaService = new SchemaService(neo4j, neo4jOptions, driverCache) schemaService.createOptimizations(structType) val scriptResult = schemaService.execute(neo4jOptions.script) schemaService.close() if (neo4jOptions.indexAwait > 0) { val session = driverCache.getOrCreate().session(neo4jOptions.session.toNeo4jSession()) session.run(s"CALL db.awaitIndexes(${neo4jOptions.indexAwait})").consume() } new Neo4jDataWriterFactory( neo4j, jobId, structType, saveMode, neo4jOptions, scriptResult ) } private val driverCache = new DriverCache(neo4jOptions.connection) override def commit(messages: Array[WriterCommitMessage]): Unit = { driverCache.close() } override def abort(messages: Array[WriterCommitMessage]): Unit = { driverCache.close() } } ================================================ FILE: spark-3/src/main/scala/org/neo4j/spark/writer/Neo4jDataWriter.scala ================================================ /* * Copyright (c) "Neo4j" * Neo4j Sweden AB [https://neo4j.com] * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. 
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package org.neo4j.spark.writer

import org.apache.spark.sql.SaveMode
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.connector.write.DataWriter
import org.apache.spark.sql.types.StructType
import org.neo4j.caniuse.Neo4j
import org.neo4j.spark.util.Neo4jOptions

/**
 * Spark 3 batch [[DataWriter]] for a single partition.
 *
 * All writing logic lives in the shared [[BaseDataWriter]]; this subclass only
 * binds it to Spark 3's `DataWriter[InternalRow]` interface, so the body is empty.
 *
 * @param neo4j        capability descriptor of the target Neo4j server
 * @param jobId        Spark job id this writer belongs to
 * @param partitionId  id of the partition handled by this writer
 * @param schema       Spark schema of the records to write
 * @param saveMode     save mode controlling create/merge semantics
 * @param options      connector options
 * @param scriptResult result rows of the optional pre-write user script
 */
class Neo4jDataWriter(
    neo4j: Neo4j,
    jobId: String,
    partitionId: Int,
    schema: StructType,
    saveMode: SaveMode,
    options: Neo4jOptions,
    scriptResult: java.util.List[java.util.Map[String, AnyRef]]
) extends BaseDataWriter(neo4j, jobId, partitionId, schema, saveMode, options, scriptResult)
    with DataWriter[InternalRow] {}
*/ package org.neo4j.spark.writer import org.apache.spark.sql.SaveMode import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.connector.write.DataWriter import org.apache.spark.sql.connector.write.DataWriterFactory import org.apache.spark.sql.connector.write.PhysicalWriteInfo import org.apache.spark.sql.types.StructType import org.neo4j.caniuse.Neo4j import org.neo4j.spark.util.Neo4jOptions class Neo4jDataWriterFactory( neo4j: Neo4j, jobId: String, structType: StructType, saveMode: SaveMode, options: Neo4jOptions, scriptResult: java.util.List[java.util.Map[String, AnyRef]] ) extends DataWriterFactory { override def createWriter(partitionId: Int, taskId: Long): DataWriter[InternalRow] = new Neo4jDataWriter( neo4j, jobId, partitionId, structType, saveMode, options, scriptResult ) } ================================================ FILE: spark-3/src/main/scala/org/neo4j/spark/writer/Neo4jWriterBuilder.scala ================================================ /* * Copyright (c) "Neo4j" * Neo4j Sweden AB [https://neo4j.com] * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. 
 */
package org.neo4j.spark.writer

import org.apache.spark.sql.SaveMode
import org.apache.spark.sql.connector.metric.CustomMetric
import org.apache.spark.sql.connector.write.BatchWrite
import org.apache.spark.sql.connector.write.SupportsOverwrite
import org.apache.spark.sql.connector.write.SupportsTruncate
import org.apache.spark.sql.connector.write.Write
import org.apache.spark.sql.connector.write.WriteBuilder
import org.apache.spark.sql.connector.write.streaming.StreamingWrite
import org.apache.spark.sql.sources.Filter
import org.apache.spark.sql.types.StructType
import org.neo4j.caniuse.Neo4j
import org.neo4j.spark.streaming.Neo4jStreamingWriter
import org.neo4j.spark.util.Neo4jOptions
import org.neo4j.spark.util.NodeSaveMode
import org.neo4j.spark.util.ValidateSaveMode
import org.neo4j.spark.util.ValidateWrite
import org.neo4j.spark.util.ValidationUtil
import org.neo4j.spark.util.Validations

/**
 * Entry point for Spark 3 writes: builds either a batch or a streaming write
 * for Neo4j, after validating the connector options against the chosen save mode.
 *
 * @param neo4j        capability descriptor of the target Neo4j server
 * @param queryId      id of the Spark query/job driving this write
 * @param schema       Spark schema of the records to write
 * @param saveMode     save mode requested by the caller (batch path)
 * @param neo4jOptions connector options
 */
class Neo4jWriterBuilder(
    neo4j: Neo4j,
    queryId: String,
    schema: StructType,
    saveMode: SaveMode,
    neo4jOptions: Neo4jOptions
) extends WriteBuilder
    with SupportsOverwrite
    with SupportsTruncate {

  override def build(): Write = new Write {
    override def description(): String = "Neo4j Writer"

    override def toBatch: BatchWrite = buildForBatch()

    override def toStreaming: StreamingWrite = buildForStreaming()

    override def supportedCustomMetrics(): Array[CustomMetric] = DataWriterMetrics.metricDeclarations()
  }

  /**
   * Runs write validations for the given save mode and returns the (unmodified)
   * options on success; validation failures are raised as exceptions.
   * Rejects relationship writes where BOTH source and target nodes use
   * ErrorIfExists, which the connector cannot express on Spark 3.
   */
  def validOptions(actualSaveMode: SaveMode): Neo4jOptions = {
    Validations.validate(ValidateWrite(
      neo4j,
      neo4jOptions,
      queryId,
      actualSaveMode,
      (o: Neo4jOptions) => {
        ValidationUtil.isFalse(
          o.relationshipMetadata.sourceSaveMode.equals(NodeSaveMode.ErrorIfExists)
            && o.relationshipMetadata.targetSaveMode.equals(NodeSaveMode.ErrorIfExists),
          "Save mode 'ErrorIfExists' is not supported on Spark 3.0, use 'Append' instead."
        )
      }
    ))
    neo4jOptions
  }

  override def buildForBatch(): BatchWrite =
    new Neo4jBatchWriter(neo4j, queryId, schema, saveMode, validOptions(saveMode))

  // Cached streaming writer so repeated buildForStreaming calls for the same
  // query/schema/options reuse one instance (and its driver/listener state).
  @volatile
  private var streamWriter: Neo4jStreamingWriter = _

  /** True when no cached writer exists or any identifying input has changed. */
  def isNewInstance(queryId: String, schema: StructType, options: Neo4jOptions): Boolean =
    streamWriter == null ||
      streamWriter.queryId != queryId ||
      streamWriter.schema != schema ||
      streamWriter.neo4jOptions != options

  override def buildForStreaming(): StreamingWrite = {
    // NOTE(review): this check-then-assign is not atomic even though streamWriter
    // is @volatile; concurrent calls could create two writers. Presumably Spark
    // invokes buildForStreaming from a single thread — confirm before relying on it.
    if (isNewInstance(queryId, schema, neo4jOptions)) {
      // In streaming mode the save mode comes from the options, not the builder's
      // constructor argument, and must itself be validated first.
      val streamingSaveMode = neo4jOptions.saveMode

      Validations.validate(ValidateSaveMode(streamingSaveMode))

      val saveMode = SaveMode.valueOf(streamingSaveMode)
      streamWriter = new Neo4jStreamingWriter(
        neo4j,
        queryId,
        schema,
        saveMode,
        validOptions(saveMode)
      )
    }

    streamWriter
  }

  // SupportsOverwrite: the filters are ignored; overwrite is applied to the
  // whole target by rebuilding with SaveMode.Overwrite.
  override def overwrite(filters: Array[Filter]): WriteBuilder = {
    new Neo4jWriterBuilder(neo4j, queryId, schema, SaveMode.Overwrite, neo4jOptions)
  }
}
 */
package org.neo4j.spark;

import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.catalyst.expressions.GenericRowWithSchema;
import org.junit.Test;
import org.neo4j.driver.Session;

import java.sql.Timestamp;
import java.time.*;
import java.sql.Date;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.TimeZone;

import static org.junit.Assert.assertEquals;

/**
 * Integration tests verifying how each Neo4j property type is mapped to a Spark
 * SQL value when read through the connector.
 *
 * Each test seeds a single {@code Person} node via {@link #initTest(String)} and
 * asserts on the value of one property column. Temporal/spatial/duration values
 * surface as structs (rows) whose first field is a type tag (e.g. "point-2d",
 * "duration", "local-time").
 */
public class DataSourceReaderTypesTSE extends SparkConnectorScalaBaseTSE {

    // --- scalar types ---

    @Test
    public void testReadNodeWithString() {
        String name = "Foobar";
        Dataset df = initTest("CREATE (p:Person {name: '" + name + "'})");

        assertEquals(name, df.select("name").collectAsList().get(0).getString(0));
    }

    @Test
    public void testReadNodeWithLong() {
        long age = 42L;
        Dataset df = initTest("CREATE (p:Person {age: " + age + "})");

        assertEquals(age, df.select("age").collectAsList().get(0).getLong(0));
    }

    @Test
    public void testReadNodeWithDouble() {
        double score = 4.2;
        Dataset df = initTest("CREATE (p:Person {score: " + score + "})");

        assertEquals(score, df.select("score").collectAsList().get(0).getDouble(0), 0);
    }

    // --- temporal types; local/offset times surface as (type-tag, string) structs ---

    @Test
    public void testReadNodeWithLocalTime() {
        Dataset df = initTest("CREATE (p:Person {aTime: localtime({hour:12, minute: 23, second: 0, millisecond: 294})})");

        GenericRowWithSchema result = df.select("aTime").collectAsList().get(0).getAs(0);
        assertEquals("local-time", result.get(0));
        assertEquals("12:23:00.294", result.get(1));
    }

    @Test
    public void testReadNodeWithTime() {
        // Uses the JVM default timezone so the expected offset matches the stored one.
        TimeZone timezone = TimeZone.getDefault();
        Dataset df = initTest("CREATE (p:Person {aTime: time({hour:12, minute: 23, second: 0, millisecond: 294, timezone: '" + timezone.getID() + "'})})");

        GenericRowWithSchema result = df.select("aTime").collectAsList().get(0).getAs(0);

        LocalTime localTime = LocalTime.of(12, 23, 0, 294000000);
        OffsetTime expectedTime = OffsetTime.of(localTime, timezone.toZoneId().getRules().getOffset(Instant.now()));

        assertEquals("offset-time", result.get(0));
        assertEquals(expectedTime.toString(), result.get(1));
    }

    @Test
    public void testReadNodeWithLocalDateTime() {
        String localDateTime = "2007-12-03T10:15:30";
        Dataset df = initTest("CREATE (p:Person {aTime: localdatetime('" + localDateTime + "')})");

        Row row = df.select("aTime").collectAsList().get(0);
        LocalDateTime result = row.getAs(0);
        assertEquals(LocalDateTime.parse(localDateTime), result);
    }

    @Test
    public void testReadNodeWithZonedDateTime() {
        // Zoned datetimes map to java.sql.Timestamp (instant semantics).
        String dateTime = "2015-06-24T12:50:35.556+01:00";
        Dataset df = initTest("CREATE (p:Person {aTime: datetime('" + dateTime + "')})");

        Timestamp result = df.select("aTime").collectAsList().get(0).getTimestamp(0);
        assertEquals(Timestamp.from(OffsetDateTime.parse(dateTime).toInstant()), result);
    }

    // --- spatial types; structs are (type-tag, srid, x, y[, z]) ---

    @Test
    public void testReadNodeWithPoint() {
        Dataset df = initTest("CREATE (p:Person {location: point({x: 12.12, y: 13.13})})");

        GenericRowWithSchema res = df.select("location").collectAsList().get(0).getAs(0);
        assertEquals("point-2d", res.get(0));
        assertEquals(7203, res.get(1)); // SRID 7203 = cartesian 2D
        assertEquals(12.12, res.get(2));
        assertEquals(13.13, res.get(3));
    }

    @Test
    public void testReadNodeWithGeoPoint() {
        Dataset df = initTest("CREATE (p:Person {location: point({longitude: 12.12, latitude: 13.13})})");

        GenericRowWithSchema res = df.select("location").collectAsList().get(0).getAs(0);
        assertEquals("point-2d", res.get(0));
        assertEquals(4326, res.get(1)); // SRID 4326 = WGS-84
        assertEquals(12.12, res.get(2));
        assertEquals(13.13, res.get(3));
    }

    @Test
    public void testReadNodeWithPoint3D() {
        Dataset df = initTest("CREATE (p:Person {location: point({x: 12.12, y: 13.13, z: 1})})");

        GenericRowWithSchema res = df.select("location").collectAsList().get(0).getAs(0);
        assertEquals("point-3d", res.get(0));
        assertEquals(9157, res.get(1)); // SRID 9157 = cartesian 3D
        assertEquals(12.12, res.get(2));
        assertEquals(13.13, res.get(3));
        assertEquals(1.0, res.get(4));
    }

    @Test
    public void testReadNodeWithDate() {
        Dataset df = initTest("CREATE (p:Person {born: date('2009-10-10')})");

        Date res = df.select("born").collectAsList().get(0).getDate(0);
        assertEquals(Date.valueOf("2009-10-10"), res);
    }

    @Test
    public void testReadNodeWithDuration() {
        // Duration struct: (type-tag, months, days, seconds, nanos, ISO string).
        Dataset df = initTest("CREATE (p:Person {range: duration({days: 14, hours:16, minutes: 12})})");

        GenericRowWithSchema res = df.select("range").collectAsList().get(0).getAs(0);
        assertEquals("duration", res.get(0));
        assertEquals(0L, res.get(1));
        assertEquals(14L, res.get(2));
        assertEquals(58320L, res.get(3));
        assertEquals(0, res.get(4));
        assertEquals("P0M14DT58320S", res.get(5));
    }

    // --- array properties ---

    @Test
    public void testReadNodeWithStringArray() {
        Dataset df = initTest("CREATE (p:Person {names: ['John', 'Doe']})");

        List res = df.select("names").collectAsList().get(0).getList(0);
        assertEquals("John", res.get(0));
        assertEquals("Doe", res.get(1));
    }

    @Test
    public void testReadNodeWithLongArray() {
        Dataset df = initTest("CREATE (p:Person {ages: [22, 23]})");

        List res = df.select("ages").collectAsList().get(0).getList(0);
        assertEquals(22L, res.get(0).longValue());
        assertEquals(23L, res.get(1).longValue());
    }

    @Test
    public void testReadNodeWithDoubleArray() {
        Dataset df = initTest("CREATE (p:Person {scores: [22.33, 44.55]})");

        List res = df.select("scores").collectAsList().get(0).getList(0);
        assertEquals(22.33, res.get(0), 0);
        assertEquals(44.55, res.get(1), 0);
    }

    @Test
    public void testReadNodeWithLocalTimeArray() {
        Dataset df = initTest("CREATE (p:Person {someTimes: [localtime({hour:12}), localtime({hour:1, minute: 3})]})");

        List res = df.select("someTimes").collectAsList().get(0).getList(0);
        assertEquals("local-time", res.get(0).get(0));
        assertEquals("12:00:00", res.get(0).get(1));
        assertEquals("local-time", res.get(1).get(0));
        assertEquals("01:03:00", res.get(1).get(1));
    }

    @Test
    public void testReadNodeWithBooleanArray() {
        Dataset df = initTest("CREATE (p:Person {bools: [true, false]})");

        List res = df.select("bools").collectAsList().get(0).getList(0);
        assertEquals(true, res.get(0));
        assertEquals(false, res.get(1));
    }

    @Test
    public void testReadNodeWithArrayDate() {
        Dataset df = initTest("CREATE (p:Person {dates: [date('2009-10-10'), date('2009-10-11')]})");

        List res = df.select("dates").collectAsList().get(0).getList(0);
        assertEquals(Date.valueOf("2009-10-10"), res.get(0));
        assertEquals(Date.valueOf("2009-10-11"), res.get(1));
    }

    @Test
    public void testReadNodeWithArrayZonedDateTime() {
        String datetime1 = "2015-06-24T12:50:35.556+01:00";
        String datetime2 = "2015-06-23T12:50:35.556+01:00";
        Dataset df = initTest("CREATE (p:Person {dates: [datetime('" + datetime1 + "'), datetime('" + datetime2 + "')]})");

        List res = df.select("dates").collectAsList().get(0).getList(0);
        assertEquals(Timestamp.from(OffsetDateTime.parse(datetime1).toInstant()), res.get(0));
        assertEquals(Timestamp.from(OffsetDateTime.parse(datetime2).toInstant()), res.get(1));
    }

    @Test
    public void testReadNodeWithArrayDurations() {
        // Fractional months/weeks are normalized by Neo4j into days+seconds+nanos.
        Dataset df = initTest("CREATE (p:Person {durations: [duration({months: 0.75}), duration({weeks: 2.5})]})");

        List res = df.select("durations").collectAsList().get(0).getList(0);
        assertEquals("duration", res.get(0).get(0));
        assertEquals(0L, res.get(0).get(1));
        assertEquals(22L, res.get(0).get(2));
        assertEquals(71509L, res.get(0).get(3));
        assertEquals(500000000, res.get(0).get(4));
        assertEquals("P0M22DT71509.500000000S", res.get(0).get(5));

        assertEquals("duration", res.get(1).get(0));
        assertEquals(0L, res.get(1).get(1));
        assertEquals(17L, res.get(1).get(2));
        assertEquals(43200L, res.get(1).get(3));
        assertEquals(0, res.get(1).get(4));
        assertEquals("P0M17DT43200S", res.get(1).get(5));
    }

    @Test
    public void testReadNodeWithPointArray() {
        Dataset df = initTest("CREATE (p:Person {locations: [point({x: 11, y: 33.111}), point({x: 22, y: 44.222})]})");

        List res = df.select("locations").collectAsList().get(0).getList(0);
        assertEquals("point-2d", res.get(0).get(0));
        assertEquals(7203, res.get(0).get(1));
        assertEquals(11.0, res.get(0).get(2));
        assertEquals(33.111, res.get(0).get(3));

        assertEquals("point-2d", res.get(1).get(0));
        assertEquals(7203, res.get(1).get(1));
        assertEquals(22.0, res.get(1).get(2));
        assertEquals(44.222, res.get(1).get(3));
    }

    @Test
    public void testReadNodeWithGeoPointArray() {
        Dataset df = initTest("CREATE (p:Person {locations: [point({longitude: 11, latitude: 33.111}), point({longitude: 22, latitude: 44.222})]})");

        List res = df.select("locations").collectAsList().get(0).getList(0);
        assertEquals("point-2d", res.get(0).get(0));
        assertEquals(4326, res.get(0).get(1));
        assertEquals(11.0, res.get(0).get(2));
        assertEquals(33.111, res.get(0).get(3));

        assertEquals("point-2d", res.get(1).get(0));
        assertEquals(4326, res.get(1).get(1));
        assertEquals(22.0, res.get(1).get(2));
        assertEquals(44.222, res.get(1).get(3));
    }

    @Test
    public void testReadNodeWithPoint3DArray() {
        Dataset df = initTest("CREATE (p:Person {locations: [point({x: 11, y: 33.111, z: 12}), point({x: 22, y: 44.222, z: 99.1})]})");

        List res = df.select("locations").collectAsList().get(0).getList(0);
        assertEquals("point-3d", res.get(0).get(0));
        assertEquals(9157, res.get(0).get(1));
        assertEquals(11.0, res.get(0).get(2));
        assertEquals(33.111, res.get(0).get(3));

        assertEquals("point-3d", res.get(1).get(0));
        assertEquals(9157, res.get(1).get(1));
        assertEquals(22.0, res.get(1).get(2));
        assertEquals(44.222, res.get(1).get(3));
    }

    @Test
    public void testReadNodeWithMap() {
        // Maps are only readable via the "query" option; values are stringified.
        Dataset df = ss().read().format(DataSource.class.getName())
                .option("url", SparkConnectorScalaSuiteIT.server().getBoltUrl())
                .option("query", "RETURN {a: 1, b: '3'} AS map")
                .load();

        Map map = df.select("map").collectAsList().get(0).getJavaMap(0);
        Map expectedMap = new HashMap<>();
        expectedMap.put("a", "1");
        expectedMap.put("b", "3");

        assertEquals(expectedMap, map);
    }

    /**
     * Seeds the test database with the given Cypher statement and returns a
     * Dataset reading all {@code Person} nodes through the connector.
     */
    Dataset initTest(String query) {
        try (Session session = SparkConnectorScalaSuiteIT.session("")) {
            session.writeTransaction(transaction -> transaction.run(query).consume());
        }

        return ss().read().format(DataSource.class.getName())
                .option("url", SparkConnectorScalaSuiteIT.server().getBoltUrl())
                .option("labels", "Person")
                .load();
    }
}
"accessTokenLifespanForImplicitFlow": 900, "ssoSessionIdleTimeout": 1800, "ssoSessionMaxLifespan": 36000, "ssoSessionIdleTimeoutRememberMe": 0, "ssoSessionMaxLifespanRememberMe": 0, "offlineSessionIdleTimeout": 2592000, "offlineSessionMaxLifespanEnabled": false, "offlineSessionMaxLifespan": 5184000, "clientSessionIdleTimeout": 0, "clientSessionMaxLifespan": 0, "clientOfflineSessionIdleTimeout": 0, "clientOfflineSessionMaxLifespan": 0, "accessCodeLifespan": 60, "accessCodeLifespanUserAction": 300, "accessCodeLifespanLogin": 1800, "actionTokenGeneratedByAdminLifespan": 43200, "actionTokenGeneratedByUserLifespan": 300, "oauth2DeviceCodeLifespan": 600, "oauth2DevicePollingInterval": 5, "enabled": true, "sslRequired": "external", "registrationAllowed": false, "registrationEmailAsUsername": false, "rememberMe": false, "verifyEmail": false, "loginWithEmailAllowed": true, "duplicateEmailsAllowed": false, "resetPasswordAllowed": false, "editUsernameAllowed": false, "bruteForceProtected": false, "permanentLockout": false, "maxTemporaryLockouts": 0, "bruteForceStrategy": "MULTIPLE", "maxFailureWaitSeconds": 900, "minimumQuickLoginWaitSeconds": 60, "waitIncrementSeconds": 60, "quickLoginCheckMilliSeconds": 1000, "maxDeltaTimeSeconds": 43200, "failureFactor": 30, "roles": { "realm": [ { "id": "4d08fd99-cc5d-45d6-9a49-cdd4fcd6448e", "name": "uma_authorization", "description": "${role_uma_authorization}", "composite": false, "clientRole": false, "containerId": "7d10aebd-a60d-45a9-bce3-bb3a0a372ae9", "attributes": {} }, { "id": "49ad4e27-73f3-48ea-adf1-ff56880176d9", "name": "offline_access", "description": "${role_offline-access}", "composite": false, "clientRole": false, "containerId": "7d10aebd-a60d-45a9-bce3-bb3a0a372ae9", "attributes": {} }, { "id": "1c26e8ea-3815-4364-b254-39280413343b", "name": "default-roles-neo4j-sso-test", "description": "${role_default-roles}", "composite": true, "composites": { "realm": [ "offline_access", "uma_authorization" ], "client": { "account": 
[ "manage-account", "view-profile" ] } }, "clientRole": false, "containerId": "7d10aebd-a60d-45a9-bce3-bb3a0a372ae9", "attributes": {} } ], "client": { "realm-management": [ { "id": "02f4e601-c963-4fdd-9246-f732ceb7ff40", "name": "manage-users", "description": "${role_manage-users}", "composite": false, "clientRole": true, "containerId": "ff825411-d253-459a-93b9-71b9f00b1276", "attributes": {} }, { "id": "860945bd-3a1b-4f1d-91e4-6506ed427e72", "name": "query-clients", "description": "${role_query-clients}", "composite": false, "clientRole": true, "containerId": "ff825411-d253-459a-93b9-71b9f00b1276", "attributes": {} }, { "id": "0b3a6477-57be-4ebe-88d0-473f9bbcaa43", "name": "query-realms", "description": "${role_query-realms}", "composite": false, "clientRole": true, "containerId": "ff825411-d253-459a-93b9-71b9f00b1276", "attributes": {} }, { "id": "07e63113-153b-4ca1-95c7-a0c161b6107a", "name": "realm-admin", "description": "${role_realm-admin}", "composite": true, "composites": { "client": { "realm-management": [ "manage-users", "query-clients", "query-realms", "view-clients", "query-users", "view-users", "manage-clients", "view-events", "query-groups", "manage-realm", "manage-authorization", "create-client", "manage-events", "view-realm", "manage-identity-providers", "impersonation", "view-authorization", "view-identity-providers" ] } }, "clientRole": true, "containerId": "ff825411-d253-459a-93b9-71b9f00b1276", "attributes": {} }, { "id": "89500893-787c-47a0-8e54-917d54508cff", "name": "query-users", "description": "${role_query-users}", "composite": false, "clientRole": true, "containerId": "ff825411-d253-459a-93b9-71b9f00b1276", "attributes": {} }, { "id": "29194689-d523-40ef-9340-88eb86c6c4bb", "name": "view-clients", "description": "${role_view-clients}", "composite": true, "composites": { "client": { "realm-management": [ "query-clients" ] } }, "clientRole": true, "containerId": "ff825411-d253-459a-93b9-71b9f00b1276", "attributes": {} }, { "id": 
"e871e84c-d6d2-4c3f-bd7b-b9a5aeac32be", "name": "manage-clients", "description": "${role_manage-clients}", "composite": false, "clientRole": true, "containerId": "ff825411-d253-459a-93b9-71b9f00b1276", "attributes": {} }, { "id": "8f862a23-e90d-4f92-8acf-fb9adb5121f0", "name": "view-users", "description": "${role_view-users}", "composite": true, "composites": { "client": { "realm-management": [ "query-groups", "query-users" ] } }, "clientRole": true, "containerId": "ff825411-d253-459a-93b9-71b9f00b1276", "attributes": {} }, { "id": "f36e6ae9-043d-47ee-b294-2f450d0f34a9", "name": "manage-realm", "description": "${role_manage-realm}", "composite": false, "clientRole": true, "containerId": "ff825411-d253-459a-93b9-71b9f00b1276", "attributes": {} }, { "id": "802b1f48-95cc-4af5-bc96-76d8a701ec24", "name": "query-groups", "description": "${role_query-groups}", "composite": false, "clientRole": true, "containerId": "ff825411-d253-459a-93b9-71b9f00b1276", "attributes": {} }, { "id": "743ede07-a12d-4d6b-bec2-0231450a28f9", "name": "view-events", "description": "${role_view-events}", "composite": false, "clientRole": true, "containerId": "ff825411-d253-459a-93b9-71b9f00b1276", "attributes": {} }, { "id": "7252234e-695b-47a2-bfa9-7c50b5bedd35", "name": "manage-authorization", "description": "${role_manage-authorization}", "composite": false, "clientRole": true, "containerId": "ff825411-d253-459a-93b9-71b9f00b1276", "attributes": {} }, { "id": "3d5a003d-3c0d-4193-9048-40179baa1f5f", "name": "create-client", "description": "${role_create-client}", "composite": false, "clientRole": true, "containerId": "ff825411-d253-459a-93b9-71b9f00b1276", "attributes": {} }, { "id": "174cb3a3-b04a-446b-88d8-d5225a658757", "name": "manage-events", "description": "${role_manage-events}", "composite": false, "clientRole": true, "containerId": "ff825411-d253-459a-93b9-71b9f00b1276", "attributes": {} }, { "id": "1596f494-4617-4794-ba44-d7add890691e", "name": "view-realm", "description": 
"${role_view-realm}", "composite": false, "clientRole": true, "containerId": "ff825411-d253-459a-93b9-71b9f00b1276", "attributes": {} }, { "id": "daf38dba-70c6-4a38-9a56-92d42bcc4731", "name": "impersonation", "description": "${role_impersonation}", "composite": false, "clientRole": true, "containerId": "ff825411-d253-459a-93b9-71b9f00b1276", "attributes": {} }, { "id": "64ea62ef-c69c-400f-af0d-f63f410a5801", "name": "manage-identity-providers", "description": "${role_manage-identity-providers}", "composite": false, "clientRole": true, "containerId": "ff825411-d253-459a-93b9-71b9f00b1276", "attributes": {} }, { "id": "f6189b24-7b3b-4acf-8240-5297b60a82e5", "name": "view-authorization", "description": "${role_view-authorization}", "composite": false, "clientRole": true, "containerId": "ff825411-d253-459a-93b9-71b9f00b1276", "attributes": {} }, { "id": "ef2dea85-dd33-4aed-932b-a8e8bd2355ca", "name": "view-identity-providers", "description": "${role_view-identity-providers}", "composite": false, "clientRole": true, "containerId": "ff825411-d253-459a-93b9-71b9f00b1276", "attributes": {} } ], "security-admin-console": [], "admin-cli": [], "neo4j-commons-client": [ { "id": "9a142f12-99da-442c-bdaa-9f8a59171404", "name": "uma_protection", "composite": false, "clientRole": true, "containerId": "004fd711-5103-47bc-9c93-bfcf38b3a10a", "attributes": {} } ], "account-console": [], "broker": [ { "id": "06e885bf-8d8f-4a7b-afde-4bf4a87959d3", "name": "read-token", "description": "${role_read-token}", "composite": false, "clientRole": true, "containerId": "7c1f30ad-9cfd-4817-9f02-7f914b125de7", "attributes": {} } ], "account": [ { "id": "1f367fdb-6d7a-4c7f-af5b-a0ed66b1c289", "name": "delete-account", "description": "${role_delete-account}", "composite": false, "clientRole": true, "containerId": "f78d45f3-369c-457e-ab92-4e1e5a97f54c", "attributes": {} }, { "id": "3fad9c59-832e-48d4-abc8-82a621a66ce9", "name": "manage-account", "description": "${role_manage-account}", "composite": 
true, "composites": { "client": { "account": [ "manage-account-links" ] } }, "clientRole": true, "containerId": "f78d45f3-369c-457e-ab92-4e1e5a97f54c", "attributes": {} }, { "id": "1d0cc167-ae58-4950-97e8-9f709ee118a6", "name": "view-consent", "description": "${role_view-consent}", "composite": false, "clientRole": true, "containerId": "f78d45f3-369c-457e-ab92-4e1e5a97f54c", "attributes": {} }, { "id": "1d83e541-8475-48e1-afd4-6fa17429275c", "name": "manage-account-links", "description": "${role_manage-account-links}", "composite": false, "clientRole": true, "containerId": "f78d45f3-369c-457e-ab92-4e1e5a97f54c", "attributes": {} }, { "id": "4e1703b0-d6f2-4504-aa4e-8670b14e88fe", "name": "view-applications", "description": "${role_view-applications}", "composite": false, "clientRole": true, "containerId": "f78d45f3-369c-457e-ab92-4e1e5a97f54c", "attributes": {} }, { "id": "cd8d0800-a6b3-43fc-bd15-0c7ce40c9183", "name": "view-profile", "description": "${role_view-profile}", "composite": false, "clientRole": true, "containerId": "f78d45f3-369c-457e-ab92-4e1e5a97f54c", "attributes": {} }, { "id": "62439512-e52e-41ef-9123-66d79931f45b", "name": "manage-consent", "description": "${role_manage-consent}", "composite": true, "composites": { "client": { "account": [ "view-consent" ] } }, "clientRole": true, "containerId": "f78d45f3-369c-457e-ab92-4e1e5a97f54c", "attributes": {} }, { "id": "d67ceb33-f9d2-482f-99c4-09a29210ee78", "name": "view-groups", "description": "${role_view-groups}", "composite": false, "clientRole": true, "containerId": "f78d45f3-369c-457e-ab92-4e1e5a97f54c", "attributes": {} } ] } }, "groups": [ { "id": "667caafa-3dbe-4604-a427-6c61ecac5c3f", "name": "admin", "path": "/admin", "subGroups": [], "attributes": {}, "realmRoles": [], "clientRoles": {} } ], "defaultRole": { "id": "1c26e8ea-3815-4364-b254-39280413343b", "name": "default-roles-neo4j-sso-test", "description": "${role_default-roles}", "composite": true, "clientRole": false, "containerId": 
"7d10aebd-a60d-45a9-bce3-bb3a0a372ae9" }, "requiredCredentials": [ "password" ], "otpPolicyType": "totp", "otpPolicyAlgorithm": "HmacSHA1", "otpPolicyInitialCounter": 0, "otpPolicyDigits": 6, "otpPolicyLookAheadWindow": 1, "otpPolicyPeriod": 30, "otpPolicyCodeReusable": false, "otpSupportedApplications": [ "totpAppFreeOTPName", "totpAppGoogleName", "totpAppMicrosoftAuthenticatorName" ], "localizationTexts": {}, "webAuthnPolicyRpEntityName": "keycloak", "webAuthnPolicySignatureAlgorithms": [ "ES256", "RS256" ], "webAuthnPolicyRpId": "", "webAuthnPolicyAttestationConveyancePreference": "not specified", "webAuthnPolicyAuthenticatorAttachment": "not specified", "webAuthnPolicyRequireResidentKey": "not specified", "webAuthnPolicyUserVerificationRequirement": "not specified", "webAuthnPolicyCreateTimeout": 0, "webAuthnPolicyAvoidSameAuthenticatorRegister": false, "webAuthnPolicyAcceptableAaguids": [], "webAuthnPolicyExtraOrigins": [], "webAuthnPolicyPasswordlessRpEntityName": "keycloak", "webAuthnPolicyPasswordlessSignatureAlgorithms": [ "ES256", "RS256" ], "webAuthnPolicyPasswordlessRpId": "", "webAuthnPolicyPasswordlessAttestationConveyancePreference": "not specified", "webAuthnPolicyPasswordlessAuthenticatorAttachment": "not specified", "webAuthnPolicyPasswordlessRequireResidentKey": "not specified", "webAuthnPolicyPasswordlessUserVerificationRequirement": "not specified", "webAuthnPolicyPasswordlessCreateTimeout": 0, "webAuthnPolicyPasswordlessAvoidSameAuthenticatorRegister": false, "webAuthnPolicyPasswordlessAcceptableAaguids": [], "webAuthnPolicyPasswordlessExtraOrigins": [], "users" : [ { "id" : "14236e67-7ec2-490f-9f25-9e14af69e52a", "username" : "john-tester", "firstName" : "John", "lastName" : "Tester", "email" : "john.tester@test.com", "emailVerified" : false, "createdTimestamp" : 1753905058600, "enabled" : true, "totp" : false, "credentials" : [ { "id" : "d0fa2175-584f-45ec-a242-283b54a84feb", "type" : "password", "userLabel" : "My password", "createdDate" : 
1753905074445, "secretData" : "{\"value\":\"XvUWimFjFwz3KoXoChpHcbypIlm/I1K/s4nYHNHQfMU=\",\"salt\":\"yEQApqyMC2yQCdU/7uLJzQ==\",\"additionalParameters\":{}}", "credentialData" : "{\"hashIterations\":5,\"algorithm\":\"argon2\",\"additionalParameters\":{\"hashLength\":[\"32\"],\"memory\":[\"7168\"],\"type\":[\"id\"],\"version\":[\"1.3\"],\"parallelism\":[\"1\"]}}" } ], "disableableCredentialTypes" : [ ], "requiredActions" : [ ], "realmRoles" : [ "default-roles-neo4j-sso-test" ], "notBefore" : 0, "groups" : [ "/admin" ] }, { "id" : "61144f9e-1bff-4b3e-beeb-a04d3e0cbe1c", "username" : "service-account-neo4j-commons-client", "emailVerified" : false, "createdTimestamp" : 1753904616478, "enabled" : true, "totp" : false, "serviceAccountClientId" : "neo4j-commons-client", "credentials" : [ ], "disableableCredentialTypes" : [ ], "requiredActions" : [ ], "realmRoles" : [ "default-roles-neo4j-sso-test" ], "clientRoles" : { "neo4j-commons-client" : [ "uma_protection" ] }, "notBefore" : 0, "groups" : [ ] } ], "scopeMappings": [ { "clientScope": "offline_access", "roles": [ "offline_access" ] } ], "clientScopeMappings": { "account": [ { "client": "account-console", "roles": [ "manage-account", "view-groups" ] } ] }, "clients": [ { "id": "f78d45f3-369c-457e-ab92-4e1e5a97f54c", "clientId": "account", "name": "${client_account}", "rootUrl": "${authBaseUrl}", "baseUrl": "/realms/neo4j-sso-test/account/", "surrogateAuthRequired": false, "enabled": true, "alwaysDisplayInConsole": false, "clientAuthenticatorType": "client-secret", "redirectUris": [ "/realms/neo4j-sso-test/account/*" ], "webOrigins": [], "notBefore": 0, "bearerOnly": false, "consentRequired": false, "standardFlowEnabled": true, "implicitFlowEnabled": false, "directAccessGrantsEnabled": false, "serviceAccountsEnabled": false, "publicClient": true, "frontchannelLogout": false, "protocol": "openid-connect", "attributes": { "realm_client": "false", "post.logout.redirect.uris": "+" }, "authenticationFlowBindingOverrides": 
{}, "fullScopeAllowed": false, "nodeReRegistrationTimeout": 0, "defaultClientScopes": [ "web-origins", "acr", "profile", "roles", "basic", "email" ], "optionalClientScopes": [ "address", "phone", "organization", "offline_access", "microprofile-jwt" ] }, { "id": "60b52f12-1c88-418b-8de5-181f69a7f12f", "clientId": "account-console", "name": "${client_account-console}", "rootUrl": "${authBaseUrl}", "baseUrl": "/realms/neo4j-sso-test/account/", "surrogateAuthRequired": false, "enabled": true, "alwaysDisplayInConsole": false, "clientAuthenticatorType": "client-secret", "redirectUris": [ "/realms/neo4j-sso-test/account/*" ], "webOrigins": [], "notBefore": 0, "bearerOnly": false, "consentRequired": false, "standardFlowEnabled": true, "implicitFlowEnabled": false, "directAccessGrantsEnabled": false, "serviceAccountsEnabled": false, "publicClient": true, "frontchannelLogout": false, "protocol": "openid-connect", "attributes": { "realm_client": "false", "post.logout.redirect.uris": "+", "pkce.code.challenge.method": "S256" }, "authenticationFlowBindingOverrides": {}, "fullScopeAllowed": false, "nodeReRegistrationTimeout": 0, "protocolMappers": [ { "id": "7281166b-a81b-485d-a227-b47299ea7a10", "name": "audience resolve", "protocol": "openid-connect", "protocolMapper": "oidc-audience-resolve-mapper", "consentRequired": false, "config": {} } ], "defaultClientScopes": [ "web-origins", "acr", "profile", "roles", "basic", "email" ], "optionalClientScopes": [ "address", "phone", "organization", "offline_access", "microprofile-jwt" ] }, { "id": "2c20b6a9-7bb6-4d41-b1bb-1375de214d5a", "clientId": "admin-cli", "name": "${client_admin-cli}", "surrogateAuthRequired": false, "enabled": true, "alwaysDisplayInConsole": false, "clientAuthenticatorType": "client-secret", "redirectUris": [], "webOrigins": [], "notBefore": 0, "bearerOnly": false, "consentRequired": false, "standardFlowEnabled": false, "implicitFlowEnabled": false, "directAccessGrantsEnabled": true, "serviceAccountsEnabled": 
false, "publicClient": true, "frontchannelLogout": false, "protocol": "openid-connect", "attributes": { "realm_client": "false", "client.use.lightweight.access.token.enabled": "true", "post.logout.redirect.uris": "+" }, "authenticationFlowBindingOverrides": {}, "fullScopeAllowed": true, "nodeReRegistrationTimeout": 0, "defaultClientScopes": [ "web-origins", "acr", "profile", "roles", "basic", "email" ], "optionalClientScopes": [ "address", "phone", "organization", "offline_access", "microprofile-jwt" ] }, { "id": "7c1f30ad-9cfd-4817-9f02-7f914b125de7", "clientId": "broker", "name": "${client_broker}", "surrogateAuthRequired": false, "enabled": true, "alwaysDisplayInConsole": false, "clientAuthenticatorType": "client-secret", "redirectUris": [], "webOrigins": [], "notBefore": 0, "bearerOnly": true, "consentRequired": false, "standardFlowEnabled": true, "implicitFlowEnabled": false, "directAccessGrantsEnabled": false, "serviceAccountsEnabled": false, "publicClient": false, "frontchannelLogout": false, "protocol": "openid-connect", "attributes": { "realm_client": "true", "post.logout.redirect.uris": "+" }, "authenticationFlowBindingOverrides": {}, "fullScopeAllowed": false, "nodeReRegistrationTimeout": 0, "defaultClientScopes": [ "web-origins", "acr", "profile", "roles", "basic", "email" ], "optionalClientScopes": [ "address", "phone", "organization", "offline_access", "microprofile-jwt" ] }, { "id": "004fd711-5103-47bc-9c93-bfcf38b3a10a", "clientId": "neo4j-commons-client", "name": "neo4j-commons-client", "description": "", "rootUrl": "", "adminUrl": "", "baseUrl": "", "surrogateAuthRequired": false, "enabled": true, "alwaysDisplayInConsole": false, "clientAuthenticatorType": "client-secret", "secret": "QNrSpbh0mxhnlYlI21UcBaz3Htb734vi", "redirectUris": [ "/*" ], "webOrigins": [ "/*" ], "notBefore": 0, "bearerOnly": false, "consentRequired": false, "standardFlowEnabled": true, "implicitFlowEnabled": false, "directAccessGrantsEnabled": true, "serviceAccountsEnabled": 
true, "authorizationServicesEnabled": true, "publicClient": false, "frontchannelLogout": true, "protocol": "openid-connect", "attributes": { "access.token.lifespan" : "3", "client.secret.creation.time" : "1750236324", "request.object.signature.alg" : "any", "request.object.encryption.alg" : "any", "client.introspection.response.allow.jwt.claim.enabled" : "true", "standard.token.exchange.enabled" : "false", "frontchannel.logout.session.required" : "true", "oauth2.device.authorization.grant.enabled" : "true", "use.jwks.url" : "false", "backchannel.logout.revoke.offline.tokens" : "false", "use.refresh.tokens" : "true", "realm_client" : "false", "oidc.ciba.grant.enabled" : "false", "client.use.lightweight.access.token.enabled" : "false", "backchannel.logout.session.required" : "true", "request.object.required" : "not required", "client_credentials.use_refresh_token" : "true", "access.token.header.type.rfc9068" : "false", "tls.client.certificate.bound.access.tokens" : "false", "require.pushed.authorization.requests" : "false", "acr.loa.map" : "{}", "display.on.consent.screen" : "false", "request.object.encryption.enc" : "any", "token.response.type.bearer.lower-case" : "false" }, "authenticationFlowBindingOverrides": {}, "fullScopeAllowed": true, "nodeReRegistrationTimeout": -1, "protocolMappers" : [ { "id" : "ac56f3e2-8898-4040-a4ed-b1a4e3a20e52", "name" : "groups", "protocol" : "openid-connect", "protocolMapper" : "oidc-group-membership-mapper", "consentRequired" : false, "config" : { "full.path" : "false", "introspection.token.claim" : "true", "userinfo.token.claim" : "false", "multivalued" : "true", "id.token.claim" : "false", "lightweight.claim" : "false", "access.token.claim" : "true", "claim.name" : "groups" } } ], "defaultClientScopes": [ "service_account", "web-origins", "acr", "profile", "roles", "basic", "email" ], "optionalClientScopes": [ "address", "phone", "organization", "offline_access", "microprofile-jwt" ] }, { "id": 
"ff825411-d253-459a-93b9-71b9f00b1276", "clientId": "realm-management", "name": "${client_realm-management}", "surrogateAuthRequired": false, "enabled": true, "alwaysDisplayInConsole": false, "clientAuthenticatorType": "client-secret", "redirectUris": [], "webOrigins": [], "notBefore": 0, "bearerOnly": true, "consentRequired": false, "standardFlowEnabled": true, "implicitFlowEnabled": false, "directAccessGrantsEnabled": false, "serviceAccountsEnabled": false, "publicClient": false, "frontchannelLogout": false, "protocol": "openid-connect", "attributes": { "realm_client": "true", "post.logout.redirect.uris": "+" }, "authenticationFlowBindingOverrides": {}, "fullScopeAllowed": false, "nodeReRegistrationTimeout": 0, "defaultClientScopes": [ "web-origins", "acr", "profile", "roles", "basic", "email" ], "optionalClientScopes": [ "address", "phone", "organization", "offline_access", "microprofile-jwt" ] }, { "id": "e45c1497-2c8f-47f2-8c11-42ad8c088e23", "clientId": "security-admin-console", "name": "${client_security-admin-console}", "rootUrl": "${authAdminUrl}", "baseUrl": "/admin/neo4j-sso-test/console/", "surrogateAuthRequired": false, "enabled": true, "alwaysDisplayInConsole": false, "clientAuthenticatorType": "client-secret", "redirectUris": [ "/admin/neo4j-sso-test/console/*" ], "webOrigins": [ "+" ], "notBefore": 0, "bearerOnly": false, "consentRequired": false, "standardFlowEnabled": true, "implicitFlowEnabled": false, "directAccessGrantsEnabled": false, "serviceAccountsEnabled": false, "publicClient": true, "frontchannelLogout": false, "protocol": "openid-connect", "attributes": { "realm_client": "false", "client.use.lightweight.access.token.enabled": "true", "post.logout.redirect.uris": "+", "pkce.code.challenge.method": "S256" }, "authenticationFlowBindingOverrides": {}, "fullScopeAllowed": true, "nodeReRegistrationTimeout": 0, "protocolMappers": [ { "id": "eb064310-4c24-4c0f-bf39-27d5767a2f21", "name": "locale", "protocol": "openid-connect", "protocolMapper": 
"oidc-usermodel-attribute-mapper", "consentRequired": false, "config": { "introspection.token.claim": "true", "userinfo.token.claim": "true", "user.attribute": "locale", "id.token.claim": "true", "access.token.claim": "true", "claim.name": "locale", "jsonType.label": "String" } } ], "defaultClientScopes": [ "web-origins", "acr", "profile", "roles", "basic", "email" ], "optionalClientScopes": [ "address", "phone", "organization", "offline_access", "microprofile-jwt" ] } ], "clientScopes": [ { "id": "6daac52d-0f54-4cbb-b159-93cae1188e26", "name": "service_account", "description": "Specific scope for a client enabled for service accounts", "protocol": "openid-connect", "attributes": { "include.in.token.scope": "false", "display.on.consent.screen": "false" }, "protocolMappers": [ { "id": "a7d31bbd-ca7b-4d61-ae3a-7d901eb84039", "name": "Client Host", "protocol": "openid-connect", "protocolMapper": "oidc-usersessionmodel-note-mapper", "consentRequired": false, "config": { "user.session.note": "clientHost", "introspection.token.claim": "true", "userinfo.token.claim": "true", "id.token.claim": "true", "access.token.claim": "true", "claim.name": "clientHost", "jsonType.label": "String" } }, { "id": "8c417a42-c005-46ff-8e30-834ca3170169", "name": "Client ID", "protocol": "openid-connect", "protocolMapper": "oidc-usersessionmodel-note-mapper", "consentRequired": false, "config": { "user.session.note": "client_id", "introspection.token.claim": "true", "userinfo.token.claim": "true", "id.token.claim": "true", "access.token.claim": "true", "claim.name": "client_id", "jsonType.label": "String" } }, { "id": "3d484d15-8210-417d-af09-5cdedab9c6da", "name": "Client IP Address", "protocol": "openid-connect", "protocolMapper": "oidc-usersessionmodel-note-mapper", "consentRequired": false, "config": { "user.session.note": "clientAddress", "introspection.token.claim": "true", "userinfo.token.claim": "true", "id.token.claim": "true", "access.token.claim": "true", "claim.name": 
"clientAddress", "jsonType.label": "String" } } ] }, { "id": "93bbac6c-0437-46ad-8917-29b8f9a03376", "name": "phone", "description": "OpenID Connect built-in scope: phone", "protocol": "openid-connect", "attributes": { "include.in.token.scope": "true", "consent.screen.text": "${phoneScopeConsentText}", "display.on.consent.screen": "true" }, "protocolMappers": [ { "id": "6695fac9-a5a7-4705-aadb-c57a09626042", "name": "phone number verified", "protocol": "openid-connect", "protocolMapper": "oidc-usermodel-attribute-mapper", "consentRequired": false, "config": { "introspection.token.claim": "true", "userinfo.token.claim": "true", "user.attribute": "phoneNumberVerified", "id.token.claim": "true", "access.token.claim": "true", "claim.name": "phone_number_verified", "jsonType.label": "boolean" } }, { "id": "c5f8b96f-7586-4468-837a-d94f2eff1f37", "name": "phone number", "protocol": "openid-connect", "protocolMapper": "oidc-usermodel-attribute-mapper", "consentRequired": false, "config": { "introspection.token.claim": "true", "userinfo.token.claim": "true", "user.attribute": "phoneNumber", "id.token.claim": "true", "access.token.claim": "true", "claim.name": "phone_number", "jsonType.label": "String" } } ] }, { "id": "0430729b-4f3e-44ce-8d2c-1eec8e077357", "name": "saml_organization", "description": "Organization Membership", "protocol": "saml", "attributes": { "display.on.consent.screen": "false" }, "protocolMappers": [ { "id": "d5918bc9-1660-402f-8952-92571a955835", "name": "organization", "protocol": "saml", "protocolMapper": "saml-organization-membership-mapper", "consentRequired": false, "config": {} } ] }, { "id": "aab0e3e4-5f85-4015-802d-d5286e9ad947", "name": "profile", "description": "OpenID Connect built-in scope: profile", "protocol": "openid-connect", "attributes": { "include.in.token.scope": "true", "consent.screen.text": "${profileScopeConsentText}", "display.on.consent.screen": "true" }, "protocolMappers": [ { "id": "75f3fd4d-04b6-4bc0-876b-07f598dbb669", 
"name": "family name", "protocol": "openid-connect", "protocolMapper": "oidc-usermodel-attribute-mapper", "consentRequired": false, "config": { "introspection.token.claim": "true", "userinfo.token.claim": "true", "user.attribute": "lastName", "id.token.claim": "true", "access.token.claim": "true", "claim.name": "family_name", "jsonType.label": "String" } }, { "id": "7e619c16-d618-4f7f-bc46-4ba4fc0d8f7b", "name": "full name", "protocol": "openid-connect", "protocolMapper": "oidc-full-name-mapper", "consentRequired": false, "config": { "id.token.claim": "true", "introspection.token.claim": "true", "access.token.claim": "true", "userinfo.token.claim": "true" } }, { "id": "8476d1e0-9dcc-4318-9a58-6f52d82a8b46", "name": "locale", "protocol": "openid-connect", "protocolMapper": "oidc-usermodel-attribute-mapper", "consentRequired": false, "config": { "introspection.token.claim": "true", "userinfo.token.claim": "true", "user.attribute": "locale", "id.token.claim": "true", "access.token.claim": "true", "claim.name": "locale", "jsonType.label": "String" } }, { "id": "09e15aea-9ce2-420f-baac-76f358b8ac5b", "name": "username", "protocol": "openid-connect", "protocolMapper": "oidc-usermodel-attribute-mapper", "consentRequired": false, "config": { "introspection.token.claim": "true", "userinfo.token.claim": "true", "user.attribute": "username", "id.token.claim": "true", "access.token.claim": "true", "claim.name": "preferred_username", "jsonType.label": "String" } }, { "id": "faf6d7d0-3418-44b4-ba41-5b6bfd105fdb", "name": "gender", "protocol": "openid-connect", "protocolMapper": "oidc-usermodel-attribute-mapper", "consentRequired": false, "config": { "introspection.token.claim": "true", "userinfo.token.claim": "true", "user.attribute": "gender", "id.token.claim": "true", "access.token.claim": "true", "claim.name": "gender", "jsonType.label": "String" } }, { "id": "710d18f8-2c62-4909-b0ea-99975876e2ae", "name": "middle name", "protocol": "openid-connect", "protocolMapper": 
"oidc-usermodel-attribute-mapper", "consentRequired": false, "config": { "introspection.token.claim": "true", "userinfo.token.claim": "true", "user.attribute": "middleName", "id.token.claim": "true", "access.token.claim": "true", "claim.name": "middle_name", "jsonType.label": "String" } }, { "id": "1871201c-4d1e-4893-8775-3bd4ee6ee4ab", "name": "picture", "protocol": "openid-connect", "protocolMapper": "oidc-usermodel-attribute-mapper", "consentRequired": false, "config": { "introspection.token.claim": "true", "userinfo.token.claim": "true", "user.attribute": "picture", "id.token.claim": "true", "access.token.claim": "true", "claim.name": "picture", "jsonType.label": "String" } }, { "id": "66d85a48-55c0-491d-b05a-9ba75093a6ab", "name": "updated at", "protocol": "openid-connect", "protocolMapper": "oidc-usermodel-attribute-mapper", "consentRequired": false, "config": { "introspection.token.claim": "true", "userinfo.token.claim": "true", "user.attribute": "updatedAt", "id.token.claim": "true", "access.token.claim": "true", "claim.name": "updated_at", "jsonType.label": "long" } }, { "id": "130c0094-a76b-4235-a6a1-45fdae43b2ba", "name": "nickname", "protocol": "openid-connect", "protocolMapper": "oidc-usermodel-attribute-mapper", "consentRequired": false, "config": { "introspection.token.claim": "true", "userinfo.token.claim": "true", "user.attribute": "nickname", "id.token.claim": "true", "access.token.claim": "true", "claim.name": "nickname", "jsonType.label": "String" } }, { "id": "8100f44e-0052-4568-8c8c-3964c2da1cfd", "name": "profile", "protocol": "openid-connect", "protocolMapper": "oidc-usermodel-attribute-mapper", "consentRequired": false, "config": { "introspection.token.claim": "true", "userinfo.token.claim": "true", "user.attribute": "profile", "id.token.claim": "true", "access.token.claim": "true", "claim.name": "profile", "jsonType.label": "String" } }, { "id": "c7d9c5e2-8b82-4337-853a-b079c4d40cd4", "name": "given name", "protocol": "openid-connect", 
"protocolMapper": "oidc-usermodel-attribute-mapper", "consentRequired": false, "config": { "introspection.token.claim": "true", "userinfo.token.claim": "true", "user.attribute": "firstName", "id.token.claim": "true", "access.token.claim": "true", "claim.name": "given_name", "jsonType.label": "String" } }, { "id": "cb9bb72d-3cfd-4ae7-98d1-567a132afbeb", "name": "birthdate", "protocol": "openid-connect", "protocolMapper": "oidc-usermodel-attribute-mapper", "consentRequired": false, "config": { "introspection.token.claim": "true", "userinfo.token.claim": "true", "user.attribute": "birthdate", "id.token.claim": "true", "access.token.claim": "true", "claim.name": "birthdate", "jsonType.label": "String" } }, { "id": "713625e4-e1cb-4ba0-9339-983e6038fb7d", "name": "website", "protocol": "openid-connect", "protocolMapper": "oidc-usermodel-attribute-mapper", "consentRequired": false, "config": { "introspection.token.claim": "true", "userinfo.token.claim": "true", "user.attribute": "website", "id.token.claim": "true", "access.token.claim": "true", "claim.name": "website", "jsonType.label": "String" } }, { "id": "5d44e729-70f0-41a0-b22b-2f51915accae", "name": "zoneinfo", "protocol": "openid-connect", "protocolMapper": "oidc-usermodel-attribute-mapper", "consentRequired": false, "config": { "introspection.token.claim": "true", "userinfo.token.claim": "true", "user.attribute": "zoneinfo", "id.token.claim": "true", "access.token.claim": "true", "claim.name": "zoneinfo", "jsonType.label": "String" } } ] }, { "id": "655a4881-b5f1-4844-a60a-e5974b6dede9", "name": "email", "description": "OpenID Connect built-in scope: email", "protocol": "openid-connect", "attributes": { "include.in.token.scope": "true", "consent.screen.text": "${emailScopeConsentText}", "display.on.consent.screen": "true" }, "protocolMappers": [ { "id": "1a203587-27fb-4c64-b1b2-6603bfc29a19", "name": "email verified", "protocol": "openid-connect", "protocolMapper": "oidc-usermodel-property-mapper", 
"consentRequired": false, "config": { "introspection.token.claim": "true", "userinfo.token.claim": "true", "user.attribute": "emailVerified", "id.token.claim": "true", "access.token.claim": "true", "claim.name": "email_verified", "jsonType.label": "boolean" } }, { "id": "db51d38a-be10-40c1-a054-6124b17711c7", "name": "email", "protocol": "openid-connect", "protocolMapper": "oidc-usermodel-attribute-mapper", "consentRequired": false, "config": { "introspection.token.claim": "true", "userinfo.token.claim": "true", "user.attribute": "email", "id.token.claim": "true", "access.token.claim": "true", "claim.name": "email", "jsonType.label": "String" } } ] }, { "id": "a2ebcfef-8d3f-493d-9deb-c0bbe0feffb4", "name": "acr", "description": "OpenID Connect scope for add acr (authentication context class reference) to the token", "protocol": "openid-connect", "attributes": { "include.in.token.scope": "false", "display.on.consent.screen": "false" }, "protocolMappers": [ { "id": "58b572de-ed96-45d4-8b64-3193c03e8e00", "name": "acr loa level", "protocol": "openid-connect", "protocolMapper": "oidc-acr-mapper", "consentRequired": false, "config": { "id.token.claim": "true", "introspection.token.claim": "true", "access.token.claim": "true", "userinfo.token.claim": "true" } } ] }, { "id": "caa43308-3e68-4a0b-866a-ec350687d020", "name": "roles", "description": "OpenID Connect scope for add user roles to the access token", "protocol": "openid-connect", "attributes": { "include.in.token.scope": "false", "consent.screen.text": "${rolesScopeConsentText}", "display.on.consent.screen": "true" }, "protocolMappers": [ { "id": "263ca184-249e-40a5-88e6-e7062ffc416d", "name": "realm roles", "protocol": "openid-connect", "protocolMapper": "oidc-usermodel-realm-role-mapper", "consentRequired": false, "config": { "user.attribute": "foo", "introspection.token.claim": "true", "access.token.claim": "true", "claim.name": "realm_access.roles", "jsonType.label": "String", "multivalued": "true" } }, { "id": 
"e08c2804-bae8-4622-892f-bf002894d804", "name": "client roles", "protocol": "openid-connect", "protocolMapper": "oidc-usermodel-client-role-mapper", "consentRequired": false, "config": { "user.attribute": "foo", "introspection.token.claim": "true", "access.token.claim": "true", "claim.name": "resource_access.${client_id}.roles", "jsonType.label": "String", "multivalued": "true" } }, { "id": "ae8f0f35-d7a1-4f4d-94b9-a167c231d75b", "name": "audience resolve", "protocol": "openid-connect", "protocolMapper": "oidc-audience-resolve-mapper", "consentRequired": false, "config": { "introspection.token.claim": "true", "access.token.claim": "true" } } ] }, { "id": "22c7fd6a-a90d-4482-85d3-dc5a51062bd4", "name": "organization", "description": "Additional claims about the organization a subject belongs to", "protocol": "openid-connect", "attributes": { "include.in.token.scope": "true", "consent.screen.text": "${organizationScopeConsentText}", "display.on.consent.screen": "true" }, "protocolMappers": [ { "id": "79e6ae6c-c710-4ccf-a9d7-c1e8e8836d6a", "name": "organization", "protocol": "openid-connect", "protocolMapper": "oidc-organization-membership-mapper", "consentRequired": false, "config": { "introspection.token.claim": "true", "multivalued": "true", "userinfo.token.claim": "true", "id.token.claim": "true", "access.token.claim": "true", "claim.name": "organization", "jsonType.label": "String" } } ] }, { "id": "c85b8ffe-1a80-4012-9642-ddf25d112c89", "name": "basic", "description": "OpenID Connect scope for add all basic claims to the token", "protocol": "openid-connect", "attributes": { "include.in.token.scope": "false", "display.on.consent.screen": "false" }, "protocolMappers": [ { "id": "d7f80a87-6368-47de-b8af-9fec687114fa", "name": "auth_time", "protocol": "openid-connect", "protocolMapper": "oidc-usersessionmodel-note-mapper", "consentRequired": false, "config": { "user.session.note": "AUTH_TIME", "introspection.token.claim": "true", "userinfo.token.claim": "true", 
"id.token.claim": "true", "access.token.claim": "true", "claim.name": "auth_time", "jsonType.label": "long" } }, { "id": "2b319ed3-d20f-407c-9cbd-c5478ef4a9d1", "name": "sub", "protocol": "openid-connect", "protocolMapper": "oidc-sub-mapper", "consentRequired": false, "config": { "introspection.token.claim": "true", "access.token.claim": "true" } } ] }, { "id": "2942a349-f522-43cc-b983-e1b2119af9db", "name": "role_list", "description": "SAML role list", "protocol": "saml", "attributes": { "consent.screen.text": "${samlRoleListScopeConsentText}", "display.on.consent.screen": "true" }, "protocolMappers": [ { "id": "40a89e81-e9eb-463d-b5dc-b124e64ffa57", "name": "role list", "protocol": "saml", "protocolMapper": "saml-role-list-mapper", "consentRequired": false, "config": { "single": "false", "attribute.nameformat": "Basic", "attribute.name": "Role" } } ] }, { "id": "0af3471d-ddde-4e01-9f64-137725d70416", "name": "offline_access", "description": "OpenID Connect built-in scope: offline_access", "protocol": "openid-connect", "attributes": { "consent.screen.text": "${offlineAccessScopeConsentText}", "display.on.consent.screen": "true" } }, { "id": "a600ed15-1dc0-4c45-8262-1bd34bba9ca5", "name": "microprofile-jwt", "description": "Microprofile - JWT built-in scope", "protocol": "openid-connect", "attributes": { "include.in.token.scope": "true", "display.on.consent.screen": "false" }, "protocolMappers": [ { "id": "095068ea-7f54-4bf6-9cf2-331ea30368dc", "name": "upn", "protocol": "openid-connect", "protocolMapper": "oidc-usermodel-attribute-mapper", "consentRequired": false, "config": { "introspection.token.claim": "true", "userinfo.token.claim": "true", "user.attribute": "username", "id.token.claim": "true", "access.token.claim": "true", "claim.name": "upn", "jsonType.label": "String" } }, { "id": "0e12f986-9b38-475c-a18a-2da1e95f4831", "name": "groups", "protocol": "openid-connect", "protocolMapper": "oidc-usermodel-realm-role-mapper", "consentRequired": false, "config": 
{ "introspection.token.claim": "true", "multivalued": "true", "userinfo.token.claim": "true", "user.attribute": "foo", "id.token.claim": "true", "access.token.claim": "true", "claim.name": "groups", "jsonType.label": "String" } } ] }, { "id": "6235a6e8-21ee-4498-9c95-b3816efe0dfc", "name": "address", "description": "OpenID Connect built-in scope: address", "protocol": "openid-connect", "attributes": { "include.in.token.scope": "true", "consent.screen.text": "${addressScopeConsentText}", "display.on.consent.screen": "true" }, "protocolMappers": [ { "id": "1b5f1955-723a-4ec5-a8eb-578fb9e9ebb0", "name": "address", "protocol": "openid-connect", "protocolMapper": "oidc-address-mapper", "consentRequired": false, "config": { "user.attribute.formatted": "formatted", "user.attribute.country": "country", "introspection.token.claim": "true", "user.attribute.postal_code": "postal_code", "userinfo.token.claim": "true", "user.attribute.street": "street", "id.token.claim": "true", "user.attribute.region": "region", "access.token.claim": "true", "user.attribute.locality": "locality" } } ] }, { "id": "a9966629-729a-4d02-931c-cdbccc2ef6e1", "name": "web-origins", "description": "OpenID Connect scope for add allowed web origins to the access token", "protocol": "openid-connect", "attributes": { "include.in.token.scope": "false", "consent.screen.text": "", "display.on.consent.screen": "false" }, "protocolMappers": [ { "id": "2038278c-eb80-42eb-828a-ce1a700bcf5d", "name": "allowed web origins", "protocol": "openid-connect", "protocolMapper": "oidc-allowed-origins-mapper", "consentRequired": false, "config": { "introspection.token.claim": "true", "access.token.claim": "true" } } ] } ], "defaultDefaultClientScopes": [ "role_list", "saml_organization", "profile", "email", "roles", "web-origins", "acr", "basic" ], "defaultOptionalClientScopes": [ "offline_access", "address", "phone", "microprofile-jwt", "organization" ], "browserSecurityHeaders": { "contentSecurityPolicyReportOnly": "", 
"xContentTypeOptions": "nosniff", "referrerPolicy": "no-referrer", "xRobotsTag": "none", "xFrameOptions": "SAMEORIGIN", "contentSecurityPolicy": "frame-src 'self'; frame-ancestors 'self'; object-src 'none';", "strictTransportSecurity": "max-age=31536000; includeSubDomains" }, "smtpServer": {}, "eventsEnabled": false, "eventsListeners": [ "jboss-logging" ], "enabledEventTypes": [], "adminEventsEnabled": false, "adminEventsDetailsEnabled": false, "identityProviders": [], "identityProviderMappers": [], "components": { "org.keycloak.services.clientregistration.policy.ClientRegistrationPolicy": [ { "id": "816e3705-90b1-4aba-988e-100ff1effac8", "name": "Allowed Client Scopes", "providerId": "allowed-client-templates", "subType": "authenticated", "subComponents": {}, "config": { "allow-default-scopes": [ "true" ] } }, { "id": "d17f76e9-3fc9-4e0b-ba62-1787cfaacf21", "name": "Trusted Hosts", "providerId": "trusted-hosts", "subType": "anonymous", "subComponents": {}, "config": { "host-sending-registration-request-must-match": [ "true" ], "client-uris-must-match": [ "true" ] } }, { "id": "de42e33e-cf5f-40c6-8c0d-7bb4fd7cb91b", "name": "Allowed Client Scopes", "providerId": "allowed-client-templates", "subType": "anonymous", "subComponents": {}, "config": { "allow-default-scopes": [ "true" ] } }, { "id": "a5989451-aad7-4b75-85cf-29340c3508d6", "name": "Allowed Protocol Mapper Types", "providerId": "allowed-protocol-mappers", "subType": "authenticated", "subComponents": {}, "config": { "allowed-protocol-mapper-types": [ "saml-user-property-mapper", "oidc-full-name-mapper", "oidc-address-mapper", "oidc-usermodel-property-mapper", "oidc-usermodel-attribute-mapper", "saml-user-attribute-mapper", "saml-role-list-mapper", "oidc-sha256-pairwise-sub-mapper" ] } }, { "id": "871c280c-5f17-4ec4-9fd6-f36171d56553", "name": "Max Clients Limit", "providerId": "max-clients", "subType": "anonymous", "subComponents": {}, "config": { "max-clients": [ "200" ] } }, { "id": 
"fbf6f1c2-aabf-40ef-b764-8a677603d2cc", "name": "Consent Required", "providerId": "consent-required", "subType": "anonymous", "subComponents": {}, "config": {} }, { "id": "a6dc5c32-28cd-424f-b8ec-db381d5a8fdd", "name": "Allowed Protocol Mapper Types", "providerId": "allowed-protocol-mappers", "subType": "anonymous", "subComponents": {}, "config": { "allowed-protocol-mapper-types": [ "oidc-address-mapper", "saml-user-attribute-mapper", "oidc-usermodel-attribute-mapper", "oidc-full-name-mapper", "saml-user-property-mapper", "oidc-sha256-pairwise-sub-mapper", "oidc-usermodel-property-mapper", "saml-role-list-mapper" ] } }, { "id": "798a0205-2a20-436d-b15f-ec19ee70a277", "name": "Full Scope Disabled", "providerId": "scope", "subType": "anonymous", "subComponents": {}, "config": {} } ], "org.keycloak.keys.KeyProvider": [ { "id": "702a660b-9d47-4d4e-b0eb-326775910293", "name": "aes-generated", "providerId": "aes-generated", "subComponents": {}, "config": { "priority": [ "100" ] } }, { "id": "025d31ab-f248-4358-87d0-ec0e27bee344", "name": "rsa-generated", "providerId": "rsa-generated", "subComponents": {}, "config": { "priority": [ "100" ] } }, { "id": "efec8cce-6c75-4e89-9017-f553468e2811", "name": "rsa-enc-generated", "providerId": "rsa-enc-generated", "subComponents": {}, "config": { "priority": [ "100" ], "algorithm": [ "RSA-OAEP" ] } }, { "id": "40d62ca1-93bc-4110-b2dc-7d7079c45774", "name": "hmac-generated-hs512", "providerId": "hmac-generated", "subComponents": {}, "config": { "priority": [ "100" ], "algorithm": [ "HS512" ] } } ] }, "internationalizationEnabled": false, "supportedLocales": [], "authenticationFlows": [ { "id": "3ce2ce6d-7743-4e1b-8a31-291fcc7ae06f", "alias": "Account verification options", "description": "Method with which to verity the existing account", "providerId": "basic-flow", "topLevel": false, "builtIn": true, "authenticationExecutions": [ { "authenticator": "idp-email-verification", "authenticatorFlow": false, "requirement": "ALTERNATIVE", 
"priority": 10, "autheticatorFlow": false, "userSetupAllowed": false }, { "authenticatorFlow": true, "requirement": "ALTERNATIVE", "priority": 20, "autheticatorFlow": true, "flowAlias": "Verify Existing Account by Re-authentication", "userSetupAllowed": false } ] }, { "id": "8267719e-7856-4694-a9eb-f405eca2aca5", "alias": "Browser - Conditional OTP", "description": "Flow to determine if the OTP is required for the authentication", "providerId": "basic-flow", "topLevel": false, "builtIn": true, "authenticationExecutions": [ { "authenticator": "conditional-user-configured", "authenticatorFlow": false, "requirement": "REQUIRED", "priority": 10, "autheticatorFlow": false, "userSetupAllowed": false }, { "authenticator": "auth-otp-form", "authenticatorFlow": false, "requirement": "REQUIRED", "priority": 20, "autheticatorFlow": false, "userSetupAllowed": false } ] }, { "id": "596aae9f-ae33-415a-b9e8-de70ee954ac5", "alias": "Browser - Conditional Organization", "description": "Flow to determine if the organization identity-first login is to be used", "providerId": "basic-flow", "topLevel": false, "builtIn": true, "authenticationExecutions": [ { "authenticator": "conditional-user-configured", "authenticatorFlow": false, "requirement": "REQUIRED", "priority": 10, "autheticatorFlow": false, "userSetupAllowed": false }, { "authenticator": "organization", "authenticatorFlow": false, "requirement": "ALTERNATIVE", "priority": 20, "autheticatorFlow": false, "userSetupAllowed": false } ] }, { "id": "068702f9-b6d9-489e-af34-8b9f325a361b", "alias": "Direct Grant - Conditional OTP", "description": "Flow to determine if the OTP is required for the authentication", "providerId": "basic-flow", "topLevel": false, "builtIn": true, "authenticationExecutions": [ { "authenticator": "conditional-user-configured", "authenticatorFlow": false, "requirement": "REQUIRED", "priority": 10, "autheticatorFlow": false, "userSetupAllowed": false }, { "authenticator": "direct-grant-validate-otp", 
"authenticatorFlow": false, "requirement": "REQUIRED", "priority": 20, "autheticatorFlow": false, "userSetupAllowed": false } ] }, { "id": "0459a690-8543-422c-bc83-94e8c62c8bdd", "alias": "First Broker Login - Conditional Organization", "description": "Flow to determine if the authenticator that adds organization members is to be used", "providerId": "basic-flow", "topLevel": false, "builtIn": true, "authenticationExecutions": [ { "authenticator": "conditional-user-configured", "authenticatorFlow": false, "requirement": "REQUIRED", "priority": 10, "autheticatorFlow": false, "userSetupAllowed": false }, { "authenticator": "idp-add-organization-member", "authenticatorFlow": false, "requirement": "REQUIRED", "priority": 20, "autheticatorFlow": false, "userSetupAllowed": false } ] }, { "id": "31cb5a9b-b319-45d7-980d-cc6cd8a9e579", "alias": "First broker login - Conditional OTP", "description": "Flow to determine if the OTP is required for the authentication", "providerId": "basic-flow", "topLevel": false, "builtIn": true, "authenticationExecutions": [ { "authenticator": "conditional-user-configured", "authenticatorFlow": false, "requirement": "REQUIRED", "priority": 10, "autheticatorFlow": false, "userSetupAllowed": false }, { "authenticator": "auth-otp-form", "authenticatorFlow": false, "requirement": "REQUIRED", "priority": 20, "autheticatorFlow": false, "userSetupAllowed": false } ] }, { "id": "7448795c-59cf-4a43-9af8-a988cdc1e7b6", "alias": "Handle Existing Account", "description": "Handle what to do if there is existing account with same email/username like authenticated identity provider", "providerId": "basic-flow", "topLevel": false, "builtIn": true, "authenticationExecutions": [ { "authenticator": "idp-confirm-link", "authenticatorFlow": false, "requirement": "REQUIRED", "priority": 10, "autheticatorFlow": false, "userSetupAllowed": false }, { "authenticatorFlow": true, "requirement": "REQUIRED", "priority": 20, "autheticatorFlow": true, "flowAlias": "Account 
verification options", "userSetupAllowed": false } ] }, { "id": "c18be560-2806-4c78-942d-00d4e32b0339", "alias": "Organization", "providerId": "basic-flow", "topLevel": false, "builtIn": true, "authenticationExecutions": [ { "authenticatorFlow": true, "requirement": "CONDITIONAL", "priority": 10, "autheticatorFlow": true, "flowAlias": "Browser - Conditional Organization", "userSetupAllowed": false } ] }, { "id": "498a92fe-00bb-4419-a9e1-c08dcc55c0b9", "alias": "Reset - Conditional OTP", "description": "Flow to determine if the OTP should be reset or not. Set to REQUIRED to force.", "providerId": "basic-flow", "topLevel": false, "builtIn": true, "authenticationExecutions": [ { "authenticator": "conditional-user-configured", "authenticatorFlow": false, "requirement": "REQUIRED", "priority": 10, "autheticatorFlow": false, "userSetupAllowed": false }, { "authenticator": "reset-otp", "authenticatorFlow": false, "requirement": "REQUIRED", "priority": 20, "autheticatorFlow": false, "userSetupAllowed": false } ] }, { "id": "f19dc4c2-1167-41f7-8573-ecf19f285c9e", "alias": "User creation or linking", "description": "Flow for the existing/non-existing user alternatives", "providerId": "basic-flow", "topLevel": false, "builtIn": true, "authenticationExecutions": [ { "authenticatorConfig": "create unique user config", "authenticator": "idp-create-user-if-unique", "authenticatorFlow": false, "requirement": "ALTERNATIVE", "priority": 10, "autheticatorFlow": false, "userSetupAllowed": false }, { "authenticatorFlow": true, "requirement": "ALTERNATIVE", "priority": 20, "autheticatorFlow": true, "flowAlias": "Handle Existing Account", "userSetupAllowed": false } ] }, { "id": "fb78126b-7892-4da6-adc0-58a9b74e0a0d", "alias": "Verify Existing Account by Re-authentication", "description": "Reauthentication of existing account", "providerId": "basic-flow", "topLevel": false, "builtIn": true, "authenticationExecutions": [ { "authenticator": "idp-username-password-form", 
"authenticatorFlow": false, "requirement": "REQUIRED", "priority": 10, "autheticatorFlow": false, "userSetupAllowed": false }, { "authenticatorFlow": true, "requirement": "CONDITIONAL", "priority": 20, "autheticatorFlow": true, "flowAlias": "First broker login - Conditional OTP", "userSetupAllowed": false } ] }, { "id": "2a393886-1bea-48bd-823b-485d4ae4a6a6", "alias": "browser", "description": "Browser based authentication", "providerId": "basic-flow", "topLevel": true, "builtIn": true, "authenticationExecutions": [ { "authenticator": "auth-cookie", "authenticatorFlow": false, "requirement": "ALTERNATIVE", "priority": 10, "autheticatorFlow": false, "userSetupAllowed": false }, { "authenticator": "auth-spnego", "authenticatorFlow": false, "requirement": "DISABLED", "priority": 20, "autheticatorFlow": false, "userSetupAllowed": false }, { "authenticator": "identity-provider-redirector", "authenticatorFlow": false, "requirement": "ALTERNATIVE", "priority": 25, "autheticatorFlow": false, "userSetupAllowed": false }, { "authenticatorFlow": true, "requirement": "ALTERNATIVE", "priority": 26, "autheticatorFlow": true, "flowAlias": "Organization", "userSetupAllowed": false }, { "authenticatorFlow": true, "requirement": "ALTERNATIVE", "priority": 30, "autheticatorFlow": true, "flowAlias": "forms", "userSetupAllowed": false } ] }, { "id": "8df87113-974a-43b0-80cc-4a86a8eb907a", "alias": "clients", "description": "Base authentication for clients", "providerId": "client-flow", "topLevel": true, "builtIn": true, "authenticationExecutions": [ { "authenticator": "client-secret", "authenticatorFlow": false, "requirement": "ALTERNATIVE", "priority": 10, "autheticatorFlow": false, "userSetupAllowed": false }, { "authenticator": "client-jwt", "authenticatorFlow": false, "requirement": "ALTERNATIVE", "priority": 20, "autheticatorFlow": false, "userSetupAllowed": false }, { "authenticator": "client-secret-jwt", "authenticatorFlow": false, "requirement": "ALTERNATIVE", "priority": 30, 
"autheticatorFlow": false, "userSetupAllowed": false }, { "authenticator": "client-x509", "authenticatorFlow": false, "requirement": "ALTERNATIVE", "priority": 40, "autheticatorFlow": false, "userSetupAllowed": false } ] }, { "id": "f3116f33-e9e1-439d-8997-5496b244640e", "alias": "direct grant", "description": "OpenID Connect Resource Owner Grant", "providerId": "basic-flow", "topLevel": true, "builtIn": true, "authenticationExecutions": [ { "authenticator": "direct-grant-validate-username", "authenticatorFlow": false, "requirement": "REQUIRED", "priority": 10, "autheticatorFlow": false, "userSetupAllowed": false }, { "authenticator": "direct-grant-validate-password", "authenticatorFlow": false, "requirement": "REQUIRED", "priority": 20, "autheticatorFlow": false, "userSetupAllowed": false }, { "authenticatorFlow": true, "requirement": "CONDITIONAL", "priority": 30, "autheticatorFlow": true, "flowAlias": "Direct Grant - Conditional OTP", "userSetupAllowed": false } ] }, { "id": "05f713a9-caaf-4f5f-8df0-3423ee1d3b2e", "alias": "docker auth", "description": "Used by Docker clients to authenticate against the IDP", "providerId": "basic-flow", "topLevel": true, "builtIn": true, "authenticationExecutions": [ { "authenticator": "docker-http-basic-authenticator", "authenticatorFlow": false, "requirement": "REQUIRED", "priority": 10, "autheticatorFlow": false, "userSetupAllowed": false } ] }, { "id": "caa3d932-2aad-42c4-9e63-de8d32c8858c", "alias": "first broker login", "description": "Actions taken after first broker login with identity provider account, which is not yet linked to any Keycloak account", "providerId": "basic-flow", "topLevel": true, "builtIn": true, "authenticationExecutions": [ { "authenticatorConfig": "review profile config", "authenticator": "idp-review-profile", "authenticatorFlow": false, "requirement": "REQUIRED", "priority": 10, "autheticatorFlow": false, "userSetupAllowed": false }, { "authenticatorFlow": true, "requirement": "REQUIRED", 
"priority": 20, "autheticatorFlow": true, "flowAlias": "User creation or linking", "userSetupAllowed": false }, { "authenticatorFlow": true, "requirement": "CONDITIONAL", "priority": 50, "autheticatorFlow": true, "flowAlias": "First Broker Login - Conditional Organization", "userSetupAllowed": false } ] }, { "id": "eba73d10-65cf-4364-8969-92571fe69b72", "alias": "forms", "description": "Username, password, otp and other auth forms.", "providerId": "basic-flow", "topLevel": false, "builtIn": true, "authenticationExecutions": [ { "authenticator": "auth-username-password-form", "authenticatorFlow": false, "requirement": "REQUIRED", "priority": 10, "autheticatorFlow": false, "userSetupAllowed": false }, { "authenticatorFlow": true, "requirement": "CONDITIONAL", "priority": 20, "autheticatorFlow": true, "flowAlias": "Browser - Conditional OTP", "userSetupAllowed": false } ] }, { "id": "e2d23ce6-e37d-46f1-83fd-b18076d81aeb", "alias": "registration", "description": "Registration flow", "providerId": "basic-flow", "topLevel": true, "builtIn": true, "authenticationExecutions": [ { "authenticator": "registration-page-form", "authenticatorFlow": true, "requirement": "REQUIRED", "priority": 10, "autheticatorFlow": true, "flowAlias": "registration form", "userSetupAllowed": false } ] }, { "id": "00d69e38-630a-4215-a635-e71849ae2abc", "alias": "registration form", "description": "Registration form", "providerId": "form-flow", "topLevel": false, "builtIn": true, "authenticationExecutions": [ { "authenticator": "registration-user-creation", "authenticatorFlow": false, "requirement": "REQUIRED", "priority": 20, "autheticatorFlow": false, "userSetupAllowed": false }, { "authenticator": "registration-password-action", "authenticatorFlow": false, "requirement": "REQUIRED", "priority": 50, "autheticatorFlow": false, "userSetupAllowed": false }, { "authenticator": "registration-recaptcha-action", "authenticatorFlow": false, "requirement": "DISABLED", "priority": 60, "autheticatorFlow": 
false, "userSetupAllowed": false }, { "authenticator": "registration-terms-and-conditions", "authenticatorFlow": false, "requirement": "DISABLED", "priority": 70, "autheticatorFlow": false, "userSetupAllowed": false } ] }, { "id": "fe3aae15-d563-4576-a709-4ba3a40c540e", "alias": "reset credentials", "description": "Reset credentials for a user if they forgot their password or something", "providerId": "basic-flow", "topLevel": true, "builtIn": true, "authenticationExecutions": [ { "authenticator": "reset-credentials-choose-user", "authenticatorFlow": false, "requirement": "REQUIRED", "priority": 10, "autheticatorFlow": false, "userSetupAllowed": false }, { "authenticator": "reset-credential-email", "authenticatorFlow": false, "requirement": "REQUIRED", "priority": 20, "autheticatorFlow": false, "userSetupAllowed": false }, { "authenticator": "reset-password", "authenticatorFlow": false, "requirement": "REQUIRED", "priority": 30, "autheticatorFlow": false, "userSetupAllowed": false }, { "authenticatorFlow": true, "requirement": "CONDITIONAL", "priority": 40, "autheticatorFlow": true, "flowAlias": "Reset - Conditional OTP", "userSetupAllowed": false } ] }, { "id": "ffedec01-f274-4b58-ab70-178ff4e8c569", "alias": "saml ecp", "description": "SAML ECP Profile Authentication Flow", "providerId": "basic-flow", "topLevel": true, "builtIn": true, "authenticationExecutions": [ { "authenticator": "http-basic-authenticator", "authenticatorFlow": false, "requirement": "REQUIRED", "priority": 10, "autheticatorFlow": false, "userSetupAllowed": false } ] } ], "authenticatorConfig": [ { "id": "3911f29c-6c6a-418c-93a8-72ba806d715b", "alias": "create unique user config", "config": { "require.password.update.after.registration": "false" } }, { "id": "dc544a68-9478-4679-8a61-76e9afb84757", "alias": "review profile config", "config": { "update.profile.on.first.login": "missing" } } ], "requiredActions": [ { "alias": "CONFIGURE_TOTP", "name": "Configure OTP", "providerId": 
"CONFIGURE_TOTP", "enabled": true, "defaultAction": false, "priority": 10, "config": {} }, { "alias": "TERMS_AND_CONDITIONS", "name": "Terms and Conditions", "providerId": "TERMS_AND_CONDITIONS", "enabled": false, "defaultAction": false, "priority": 20, "config": {} }, { "alias": "UPDATE_PASSWORD", "name": "Update Password", "providerId": "UPDATE_PASSWORD", "enabled": true, "defaultAction": false, "priority": 30, "config": {} }, { "alias": "UPDATE_PROFILE", "name": "Update Profile", "providerId": "UPDATE_PROFILE", "enabled": true, "defaultAction": false, "priority": 40, "config": {} }, { "alias": "VERIFY_EMAIL", "name": "Verify Email", "providerId": "VERIFY_EMAIL", "enabled": true, "defaultAction": false, "priority": 50, "config": {} }, { "alias": "delete_account", "name": "Delete Account", "providerId": "delete_account", "enabled": false, "defaultAction": false, "priority": 60, "config": {} }, { "alias": "webauthn-register", "name": "Webauthn Register", "providerId": "webauthn-register", "enabled": true, "defaultAction": false, "priority": 70, "config": {} }, { "alias": "webauthn-register-passwordless", "name": "Webauthn Register Passwordless", "providerId": "webauthn-register-passwordless", "enabled": true, "defaultAction": false, "priority": 80, "config": {} }, { "alias": "VERIFY_PROFILE", "name": "Verify Profile", "providerId": "VERIFY_PROFILE", "enabled": true, "defaultAction": false, "priority": 90, "config": {} }, { "alias": "delete_credential", "name": "Delete Credential", "providerId": "delete_credential", "enabled": true, "defaultAction": false, "priority": 100, "config": {} }, { "alias": "update_user_locale", "name": "Update User Locale", "providerId": "update_user_locale", "enabled": true, "defaultAction": false, "priority": 1000, "config": {} } ], "browserFlow": "browser", "registrationFlow": "registration", "directGrantFlow": "direct grant", "resetCredentialsFlow": "reset credentials", "clientAuthenticationFlow": "clients", "dockerAuthenticationFlow": 
"docker auth", "firstBrokerLoginFlow": "first broker login", "attributes": { "cibaBackchannelTokenDeliveryMode": "poll", "cibaExpiresIn": "120", "cibaAuthRequestedUserHint": "login_hint", "oauth2DeviceCodeLifespan": "600", "clientOfflineSessionMaxLifespan": "0", "oauth2DevicePollingInterval": "5", "clientSessionIdleTimeout": "0", "parRequestUriLifespan": "60", "clientSessionMaxLifespan": "0", "clientOfflineSessionIdleTimeout": "0", "cibaInterval": "5", "realmReusableOtpCode": "false" }, "keycloakVersion": "26.2.5", "userManagedAccessAllowed": false, "organizationsEnabled": false, "verifiableCredentialsEnabled": false, "adminPermissionsEnabled": false, "clientProfiles": { "profiles": [] }, "clientPolicies": { "policies": [] } } ================================================ FILE: spark-3/src/test/scala/org/neo4j/spark/DataSourceAggregationTSE.scala ================================================ /* * Copyright (c) "Neo4j" * Neo4j Sweden AB [https://neo4j.com] * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. 
*/

package org.neo4j.spark

import org.junit.Assert.assertEquals
import org.junit.Test
import org.neo4j.driver.Transaction
import org.neo4j.driver.TransactionWork
import org.neo4j.driver.summary.ResultSummary

/**
 * Integration tests verifying that Spark SQL aggregations (SUM, MAX/MIN, COUNT)
 * executed over a temp view backed by the connector's relationship read mode
 * return the expected values.
 */
class DataSourceReaderAggregationTSE extends SparkConnectorScalaBaseTSE {

  @Test
  def testShouldDoSumAggregation(): Unit = {
    // One product priced 1 plus products priced 1..10:
    // SUM(DISTINCT price) = 55, SUM(price) = 56.
    val fixtureQuery: String =
      s"""CREATE (pe:Person {id: 1, fullName: 'Person'})-[:BOUGHT{when: rand(), quantity: rand() * 1000}]->(pr:Product {id: 0, name: 'Product ' + 0, price: 1})
         |WITH pe
         |UNWIND range(1, 10) as id
         |CREATE (pr:Product {id: id * rand(), name: 'Product ' + id, price: id})
         |CREATE (pe)-[:BOUGHT{when: rand(), quantity: rand() * 1000}]->(pr)
         |RETURN *
      """.stripMargin
    SparkConnectorScalaSuiteIT.session()
      .writeTransaction(
        new TransactionWork[ResultSummary] {
          override def execute(tx: Transaction): ResultSummary = tx.run(fixtureQuery).consume()
        }
      )

    ss.read
      .format(classOf[DataSource].getName)
      .option("url", SparkConnectorScalaSuiteIT.server.getBoltUrl)
      .option("relationship", "BOUGHT")
      .option("relationship.source.labels", "Person")
      .option("relationship.target.labels", "Product")
      .load
      .createTempView("BOUGHT")

    val df = ss.sql(
      """SELECT `source.fullName`, SUM(DISTINCT(`target.price`)) AS distinctTotal, SUM(`target.price`) AS total
        |FROM BOUGHT
        |group by `source.fullName`""".stripMargin
    )

    val rows = df.collect().toList
    assertEquals(1, rows.length)
    val row = rows(0)
    assertEquals("Person", row.getAs[String]("source.fullName"))
    assertEquals(55L, row.getAs[Long]("distinctTotal"))
    assertEquals(56L, row.getAs[Long]("total"))
  }

  @Test
  def testShouldDoMaxMinAggregation(): Unit = {
    // Products priced 1..10: MAX(price) = 10, MIN(price) = 1.
    val fixtureQuery: String =
      s"""CREATE (pe:Person {id: 1, fullName: 'Person'})
         |WITH pe
         |UNWIND range(1, 10) as id
         |CREATE (pr:Product {id: id * rand(), name: 'Product ' + id, price: id})
         |CREATE (pe)-[:BOUGHT{when: rand(), quantity: rand() * 1000}]->(pr)
         |RETURN *
      """.stripMargin
    SparkConnectorScalaSuiteIT.session()
      .writeTransaction(
        new TransactionWork[ResultSummary] {
          override def execute(tx: Transaction): ResultSummary = tx.run(fixtureQuery).consume()
        }
      )

    ss.read
      .format(classOf[DataSource].getName)
      .option("url", SparkConnectorScalaSuiteIT.server.getBoltUrl)
      .option("relationship", "BOUGHT")
      .option("relationship.source.labels", "Person")
      .option("relationship.target.labels", "Product")
      .load
      .createTempView("BOUGHT")

    val df = ss.sql(
      """SELECT `source.fullName`, MAX(`target.price`) AS max, MIN(`target.price`) AS min
        |FROM BOUGHT
        |GROUP BY `source.fullName`""".stripMargin
    )

    val rows = df.collect().toList
    assertEquals(1, rows.length)
    val row = rows(0)
    assertEquals("Person", row.getAs[String]("source.fullName"))
    assertEquals(10L, row.getAs[Long]("max"))
    assertEquals(1L, row.getAs[Long]("min"))
  }

  @Test
  def testShouldDoCountAggregation(): Unit = {
    // 11 BOUGHT rows over 10 distinct product ids (product id 1 is bought twice):
    // COUNT(DISTINCT id) = 10, COUNT(id) = 11.
    val fixtureQuery: String =
      s"""CREATE (pe:Person {id: 1, fullName: 'Person'})-[:BOUGHT{when: rand(), quantity: rand() * 1000}]->(pr:Product {id: 1, name: 'Product 1', price: 1})
         |WITH pe
         |UNWIND range(1, 10) as id
         |MERGE (pr:Product {id: id, name: 'Product ' + id, price: id})
         |CREATE (pe)-[:BOUGHT{when: rand(), quantity: rand() * 1000}]->(pr)
         |RETURN *
      """.stripMargin
    SparkConnectorScalaSuiteIT.session()
      .writeTransaction(
        new TransactionWork[ResultSummary] {
          override def execute(tx: Transaction): ResultSummary = tx.run(fixtureQuery).consume()
        }
      )

    ss.read
      .format(classOf[DataSource].getName)
      .option("url", SparkConnectorScalaSuiteIT.server.getBoltUrl)
      .option("relationship", "BOUGHT")
      .option("relationship.source.labels", "Person")
      .option("relationship.target.labels", "Product")
      .load
      .createTempView("BOUGHT")

    val df = ss.sql(
      """SELECT `source.fullName`, COUNT(DISTINCT(`target.id`)) AS distinctTotal, COUNT(`target.id`) AS total
        |FROM BOUGHT
        |group by `source.fullName`""".stripMargin
    )

    val rows = df.collect().toList
    assertEquals(1, rows.length)
    val row = rows(0)
    assertEquals("Person", row.getAs[String]("source.fullName"))
    assertEquals(10L, row.getAs[Long]("distinctTotal"))
    assertEquals(11L, row.getAs[Long]("total"))
  }
}

================================================
FILE: spark-3/src/test/scala/org/neo4j/spark/DataSourceReaderNeo4jTSE.scala
================================================

/*
 * Copyright (c) "Neo4j"
 * Neo4j Sweden AB [https://neo4j.com]
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.neo4j.spark

import org.apache.spark.SparkException
import org.apache.spark.sql.DataFrame
import org.junit.Assert.assertEquals
import org.junit.Assert.assertTrue
import org.junit.Assert.fail
import org.junit.Assume
import org.junit.BeforeClass
import org.junit.Test
import org.neo4j.Closeables.use
import org.neo4j.driver.SessionConfig
import org.neo4j.driver.Transaction
import org.neo4j.driver.TransactionWork
import org.neo4j.driver.exceptions.ClientException
import org.neo4j.driver.summary.ResultSummary

class DataSourceReaderNeo4jTSE extends SparkConnectorScalaBaseTSE {

  @Test
  def testMultiDbJoin(): Unit = {
    SparkConnectorScalaSuiteIT.driver.session(SessionConfig.forDatabase("db1"))
      .writeTransaction(
        new TransactionWork[ResultSummary] {
          override def
execute(tx: Transaction): ResultSummary = tx.run(
            """
      CREATE (p1:Person:Customer {name: 'John Doe'}),
        (p2:Person:Customer {name: 'Mark Brown'}),
        (p3:Person:Customer {name: 'Cindy White'})
      """
          ).consume()
        }
      )
    SparkConnectorScalaSuiteIT.driver.session(SessionConfig.forDatabase("db2"))
      .writeTransaction(
        new TransactionWork[ResultSummary] {
          override def execute(tx: Transaction): ResultSummary = tx.run(
            """
      CREATE (p1:Person:Employee {name: 'Jane Doe'}),
        (p2:Person:Employee {name: 'John Doe'})
      """
          ).consume()
        }
      )

    val df1 = ss.read.format(classOf[DataSource].getName)
      .option("url", SparkConnectorScalaSuiteIT.server.getBoltUrl)
      .option("database", "db1")
      .option("labels", "Person")
      .load()
    val df2 = ss.read.format(classOf[DataSource].getName)
      .option("url", SparkConnectorScalaSuiteIT.server.getBoltUrl)
      .option("database", "db2")
      .option("labels", "Person")
      .load()

    assertEquals(3, df1.count())
    assertEquals(2, df2.count())

    // 'John Doe' is the only name present in both databases.
    val dfJoin = df1.join(df2, df1("name") === df2("name"))
    assertEquals(1, dfJoin.count())
  }

  @Test
  def testReadQueryCustomPartitions(): Unit = {
    val fixtureProduct1Query: String =
      """CREATE (pr:Product{id: 1, name: 'Product 1'})
        |WITH pr
        |UNWIND range(1,100) as id
        |CREATE (p:Person {id: id, name: 'Person ' + id})-[:BOUGHT{quantity: ceil(rand() * 100)}]->(pr)
        |RETURN *
      """.stripMargin
    SparkConnectorScalaSuiteIT.driver.session()
      .writeTransaction(
        new TransactionWork[ResultSummary] {
          override def execute(tx: Transaction): ResultSummary = tx.run(fixtureProduct1Query).consume()
        }
      )
    val fixtureProduct2Query: String =
      """CREATE (pr:Product{id: 2, name: 'Product 2'})
        |WITH pr
        |UNWIND range(1,50) as id
        |MATCH (p:Person {id: id})
        |CREATE (p)-[:BOUGHT{quantity: ceil(rand() * 100)}]->(pr)
        |RETURN *
      """.stripMargin
    SparkConnectorScalaSuiteIT.driver.session()
      .writeTransaction(
        new TransactionWork[ResultSummary] {
          override def execute(tx: Transaction): ResultSummary = tx.run(fixtureProduct2Query).consume()
        }
      )

    val partitionedDf = ss.read.format(classOf[DataSource].getName)
      .option("url", SparkConnectorScalaSuiteIT.server.getBoltUrl)
      .option(
        "query",
        """
          |MATCH (p:Person)-[r:BOUGHT]->(pr:Product)
          |RETURN p.name AS person, pr.name AS product, r.quantity AS quantity""".stripMargin
      )
      .option("partitions", "5")
      .load()

    assertEquals(5, partitionedDf.rdd.getNumPartitions)
    val rows = partitionedDf.collect()
      .map(row => s"${row.getAs[String]("person")}-${row.getAs[String]("product")}")
    // 100 + 50 BOUGHT relationships in total.
    // NOTE(review): this assertion is duplicated in the original; kept as-is.
    assertEquals(150, rows.size)
    assertEquals(150, rows.size)
  }

  @Test
  def testCallShouldReturnCorrectSchema(): Unit = {
    val callDf: DataFrame = ss.read.format(classOf[DataSource].getName)
      .option("url", SparkConnectorScalaSuiteIT.server.getBoltUrl)
      .option("query", "CALL db.info() YIELD id, name RETURN *")
      .load()

    val res = callDf.select("name")
      .collectAsList()
      .get(0)

    assertEquals(res.getString(0), "neo4j")
  }

  @Test
  def testShouldReturnJustTheSelectedFieldWithNode(): Unit = {
    val total = 100
    val fixtureQuery: String =
      s"""UNWIND range(1, $total) as id
         |CREATE (pr:Product {id: id, name: 'Product ' + id})
         |RETURN *
      """.stripMargin
    SparkConnectorScalaSuiteIT.session()
      .writeTransaction(
        new TransactionWork[ResultSummary] {
          override def execute(tx: Transaction): ResultSummary = tx.run(fixtureQuery).consume()
        }
      )

    val df = ss.read
      .format(classOf[DataSource].getName)
      .option("url", SparkConnectorScalaSuiteIT.server.getBoltUrl)
      .option("labels", "Product")
      .load
      .select("name")

    df.count()
    assertEquals(Seq("name"), df.columns.toSeq)
  }

  @Test
  def testShouldReturnJustTheSelectedFieldWithNodeAndWeirdColumnName(): Unit = {
    val total = 100
    val fixtureQuery: String =
      s"""UNWIND range(1, $total) as id
         |CREATE (pr:Product {id: id, `(╯°□°)╯︵ ┻━┻`: 'Product ' + id})
         |RETURN *
      """.stripMargin
    SparkConnectorScalaSuiteIT.session()
      .writeTransaction(
        new TransactionWork[ResultSummary] {
          override def execute(tx: Transaction): ResultSummary = tx.run(fixtureQuery).consume()
        }
      )

    val df = ss.read
      .format(classOf[DataSource].getName)
      .option("url", SparkConnectorScalaSuiteIT.server.getBoltUrl)
      .option("labels", "Product")
      .load
      .select("`(╯°□°)╯︵ ┻━┻`")

    df.count()
    assertEquals(Seq("(╯°□°)╯︵ ┻━┻"), df.columns.toSeq)
  }

  @Test
  def testShouldReturnJustTheSelectedFieldWithRelationship(): Unit = {
    val total = 100
    val fixtureQuery: String =
      s"""UNWIND range(1, $total) as id
         |CREATE (pr:Product {id: id * rand(), name: 'Product ' + id})
         |CREATE (pe:Person {id: id, fullName: 'Person ' + id})
         |CREATE (pe)-[:BOUGHT{when: rand(), quantity: rand() * 1000}]->(pr)
         |RETURN *
      """.stripMargin
    SparkConnectorScalaSuiteIT.session()
      .writeTransaction(
        new TransactionWork[ResultSummary] {
          override def execute(tx: Transaction): ResultSummary = tx.run(fixtureQuery).consume()
        }
      )

    // NOTE(review): the second selected column appears here as the empty name ``;
    // this looks like an angle-bracketed internal column (e.g. `<rel.id>`) stripped
    // by text extraction — verify against the repository sources.
    val df = ss.read
      .format(classOf[DataSource].getName)
      .option("url", SparkConnectorScalaSuiteIT.server.getBoltUrl)
      .option("relationship", "BOUGHT")
      .option("relationship.source.labels", "Product")
      .option("relationship.target.labels", "Person")
      .load
      .select("`source.name`", "``")

    df.count()
    assertEquals(Seq("source.name", ""), df.columns.toSeq)
  }

  @Test
  def testShouldReturnJustTheSelectedFieldWithRelationshipAndWeirdColumn(): Unit = {
    val total = 100
    val fixtureQuery: String =
      s"""UNWIND range(1, $total) as id
         |CREATE (pr:Product {id: id * rand(), `(╯°□°)╯︵ ┻━┻`: 'Product ' + id})
         |CREATE (pe:Person {id: id, fullName: 'Person ' + id})
         |CREATE (pe)-[:BOUGHT{when: rand(), quantity: rand() * 1000}]->(pr)
         |RETURN *
      """.stripMargin
    SparkConnectorScalaSuiteIT.session()
      .writeTransaction(
        new TransactionWork[ResultSummary] {
          override def execute(tx: Transaction): ResultSummary = tx.run(fixtureQuery).consume()
        }
      )

    // NOTE(review): `` empty column name — see note in the previous test.
    val df = ss.read
      .format(classOf[DataSource].getName)
      .option("url", SparkConnectorScalaSuiteIT.server.getBoltUrl)
      .option("relationship", "BOUGHT")
      .option("relationship.source.labels", "Person")
      .option("relationship.target.labels", "Product")
      .load
      .select("`target.(╯°□°)╯︵ ┻━┻`", "``")

    df.count()
    assertEquals(Seq("target.(╯°□°)╯︵ ┻━┻", ""), df.columns.toSeq)
  }

  @Test
  def testShouldReturnJustTheSelectedFieldWithQuery(): Unit = {
    val total = 100
    val fixtureQuery: String =
      s"""UNWIND range(1, $total) as id
         |CREATE (pr:Product {id: id * rand(), name: 'Product ' + id})
         |CREATE (pe:Person {id: id, fullName: 'Person ' + id})
         |CREATE (pe)-[:BOUGHT{when: rand(), quantity: rand() * 1000}]->(pr)
         |RETURN *
      """.stripMargin
    SparkConnectorScalaSuiteIT.session()
      .writeTransaction(
        new TransactionWork[ResultSummary] {
          override def execute(tx: Transaction): ResultSummary = tx.run(fixtureQuery).consume()
        }
      )

    val df = ss.read
      .format(classOf[DataSource].getName)
      .option("url", SparkConnectorScalaSuiteIT.server.getBoltUrl)
      .option("query", "MATCH (p:Product) RETURN p.name as name")
      .option("partitions", 2)
      .option("query.count", 20)
      .load
      .select("name")

    df.count()
    assertEquals(Seq("name"), df.columns.toSeq)
  }

  @Test
  def testShouldReturnJustTheSelectedFieldWithFilter(): Unit = {
    val total = 100
    val fixtureQuery: String =
      s"""UNWIND range(1, $total) as id
         |CREATE (pr:Product {id: id, name: 'Product ' + id})
         |RETURN *
      """.stripMargin
    SparkConnectorScalaSuiteIT.session()
      .writeTransaction(
        new TransactionWork[ResultSummary] {
          override def execute(tx: Transaction): ResultSummary = tx.run(fixtureQuery).consume()
        }
      )

    val df = ss.read
      .format(classOf[DataSource].getName)
      .option("url", SparkConnectorScalaSuiteIT.server.getBoltUrl)
      .option("labels", "Product")
      .load
      .filter("name = 'Product 1'")

    df.count()
    // NOTE(review): the two leading empty column names look like stripped
    // angle-bracketed internal columns (e.g. `<id>`, `<labels>`) — verify
    // against the repository sources.
    assertEquals(Seq("", "", "name", "id"), df.columns.toSeq)
  }

  @Test
  def testShouldReturnJustTheSelectedFieldWithRelationshipWithFilter(): Unit = {
    val total = 100
    val fixtureQuery: String =
      s"""UNWIND range(1, $total) as id
         |CREATE (pr:Product {id: id * rand(), name: 'Product ' + id})
         |CREATE (pe:Person {id: id, fullName: 'Person ' + id})
         |CREATE (pe)-[:BOUGHT{when: rand(), quantity: rand() * 1000}]->(pr)
         |RETURN *
      """.stripMargin
    SparkConnectorScalaSuiteIT.session()
      .writeTransaction(
        new TransactionWork[ResultSummary] {
          override def execute(tx: Transaction): ResultSummary = tx.run(fixtureQuery).consume()
        }
      )

    val df = ss.read
      .format(classOf[DataSource].getName)
      .option("url", SparkConnectorScalaSuiteIT.server.getBoltUrl)
      .option("relationship", "BOUGHT")
      .option("relationship.source.labels", "Person")
      .option("relationship.target.labels", "Product")
      .load
      .filter("`target.name` = 'Product 1' AND `target.id` = '16'")
      .select("`target.name`", "`target.id`")

    df.count()
    assertEquals(Seq("target.name", "target.id"), df.columns.toSeq)
  }

  @Test
  def testShouldThrowClearErrorIfAWrongDbIsSpecified(): Unit = {
    try {
      ss.read.format(classOf[DataSource].getName)
        .option("url", SparkConnectorScalaSuiteIT.server.getBoltUrl)
        .option("database", "not_existing_db")
        .option("labels", "MATCH (h:Household) RETURN id(h)")
        .load()
        .show()
    } catch {
      case clientException: ClientException => {
        assertTrue(clientException.getMessage.equals(
          "Database does not exist. Database name: 'not_existing_db'."
        ))
      }
      case generic: Throwable =>
        fail(s"should be thrown a ${classOf[SparkException].getName}, got ${generic.getClass} instead")
    }
  }

  @Test
  def testEmptyDataset(): Unit = {
    val df = ss.read
      .format(classOf[DataSource].getName)
      .option("url", SparkConnectorScalaSuiteIT.server.getBoltUrl)
      .option("query", "MATCH (e:ID_DO_NOT_EXIST) RETURN id(e) as f, 1 as g")
      .load

    assertEquals(0, df.count())
    assertEquals(Seq("f", "g"), df.columns.toSeq)
  }

  @Test
  def testColumnSorted(): Unit = {
    SparkConnectorScalaSuiteIT.session()
      .writeTransaction(
        new TransactionWork[ResultSummary] {
          override def execute(tx: Transaction): ResultSummary =
            tx.run("CREATE (i1:Instrument{name: 'Drums', id: 1}), (i2:Instrument{name: 'Guitar', id: 2})").consume()
        }
      )

    val df = ss.read
      .format(classOf[DataSource].getName)
      .option("url", SparkConnectorScalaSuiteIT.server.getBoltUrl)
      .option("query", "MATCH (i:Instrument) RETURN id(i) as internal_id, i.id as id, i.name as name, i.name")
      .load
      .orderBy("id")

    assertEquals(1L, df.collectAsList().get(0).get(1))
    assertEquals("Drums", df.collectAsList().get(0).get(2))
    assertEquals(Seq("internal_id", "id", "name", "i.name"), df.columns.toSeq)
  }

  @Test
  def testComplexReturnStatement(): Unit = {
    val total = 100
    val fixtureQuery: String =
      s"""UNWIND range(1, $total) as id
         |CREATE (pr:Product {id: id * rand(), name: 'Product ' + id})
         |CREATE (pe:Person {id: id, fullName: 'Person ' + id})
         |CREATE (pe)-[:BOUGHT{when: rand(), quantity: rand() * 1000}]->(pr)
         |RETURN *
      """.stripMargin
    use(SparkConnectorScalaSuiteIT.session()) { session =>
      session
        .writeTransaction(
          new TransactionWork[ResultSummary] {
            override def execute(tx: Transaction): ResultSummary = tx.run(fixtureQuery).consume()
          }
        )
    }

    val df = ss.read.format(classOf[DataSource].getName)
      .option("url", SparkConnectorScalaSuiteIT.server.getBoltUrl)
      .option(
        "query",
        """MATCH (p:Person)-[b:BOUGHT]->(pr:Product)
          |RETURN id(p) AS personId, id(pr) AS productId, {quantity: b.quantity, when: b.when} AS map, "some string" as someString, {anotherField: "201"} as map2""".stripMargin
      )
      .option("schema.strategy", "string")
      .load()

    assertEquals(Seq("personId", "productId", "map", "someString", "map2"), df.columns.toSeq)
    assertEquals(100, df.count())
  }

  @Test
  def testComplexReturnStatementNoValues(): Unit = {
    val df = ss.read.format(classOf[DataSource].getName)
      .option("url", SparkConnectorScalaSuiteIT.server.getBoltUrl)
      .option(
        "query",
        """MATCH (p:Person)-[b:BOUGHT]->(pr:Product)
          |RETURN id(p) AS personId, id(pr) AS productId, {quantity: b.quantity, when: b.when} AS map, "some string" as someString, {anotherField: "201", and: 1} as map2""".stripMargin
      )
      .option("schema.strategy", "string")
      .load()

    assertEquals(Seq("personId", "productId", "map", "someString", "map2"), df.columns.toSeq)
    assertEquals(0, df.count())
  }
}

================================================
FILE: spark-3/src/test/scala/org/neo4j/spark/DataSourceReaderNeo4jWithApocTSE.scala
================================================

/*
 * Copyright (c) "Neo4j"
 * Neo4j Sweden AB [https://neo4j.com]
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
* You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ package org.neo4j.spark import org.junit.Assert.assertEquals import org.junit.Assume import org.junit.BeforeClass import org.junit.Test import org.neo4j.driver.SessionConfig import org.neo4j.driver.Transaction import org.neo4j.driver.TransactionWork import org.neo4j.driver.summary.ResultSummary class DataSourceReaderNeo4jWithApocTSE extends SparkConnectorScalaBaseWithApocTSE { @Test def testMultiDbJoin(): Unit = { SparkConnectorScalaSuiteWithApocIT.driver.session(SessionConfig.forDatabase("db1")) .writeTransaction( new TransactionWork[ResultSummary] { override def execute(tx: Transaction): ResultSummary = tx.run( """ CREATE (p1:Person:Customer {name: 'John Doe'}), (p2:Person:Customer {name: 'Mark Brown'}), (p3:Person:Customer {name: 'Cindy White'}) """ ).consume() } ) SparkConnectorScalaSuiteWithApocIT.driver.session(SessionConfig.forDatabase("db2")) .writeTransaction( new TransactionWork[ResultSummary] { override def execute(tx: Transaction): ResultSummary = tx.run( """ CREATE (p1:Person:Employee {name: 'Jane Doe'}), (p2:Person:Employee {name: 'John Doe'}) """ ).consume() } ) val df1 = ss.read.format(classOf[DataSource].getName) .option("url", SparkConnectorScalaSuiteWithApocIT.server.getBoltUrl) .option("database", "db1") .option("labels", "Person") .load() val df2 = ss.read.format(classOf[DataSource].getName) .option("url", SparkConnectorScalaSuiteWithApocIT.server.getBoltUrl) .option("database", "db2") .option("labels", "Person") .load() assertEquals(3, df1.count()) assertEquals(2, df2.count()) val dfJoin = df1.join(df2, df1("name") === df2("name")) 
assertEquals(1, dfJoin.count()) } @Test def testReturnProcedure(): Unit = { val query = """RETURN apoc.convert.toSet([1,1,3]) AS foo, 'bar' AS bar |""".stripMargin val df = ss.read.format(classOf[DataSource].getName) .option("url", SparkConnectorScalaSuiteWithApocIT.server.getBoltUrl) .option("partitions", 1) .option("query", query) .load assertEquals(Seq("foo", "bar"), df.columns.toSeq) // ordering should be preserved assertEquals(1, df.count()) } } ================================================ FILE: spark-3/src/test/scala/org/neo4j/spark/DataSourceReaderTSE.scala ================================================ /* * Copyright (c) "Neo4j" * Neo4j Sweden AB [https://neo4j.com] * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. 
*/ package org.neo4j.spark import org.apache.spark.sql.DataFrame import org.apache.spark.sql.Row import org.apache.spark.sql.catalyst.expressions.GenericRowWithSchema import org.apache.spark.sql.functions.col import org.apache.spark.sql.types.DataTypes import org.apache.spark.sql.types.StructField import org.apache.spark.sql.types.StructType import org.junit.Assert._ import org.junit.Test import org.neo4j.driver.Transaction import org.neo4j.driver.TransactionWork import org.neo4j.driver.summary.ResultSummary import java.sql.Timestamp import java.time._ import java.util.TimeZone import scala.collection.JavaConverters._ import scala.collection.mutable.Seq class DataSourceReaderTSE extends SparkConnectorScalaBaseTSE { @Test def testThrowsExceptionIfNoValidReadOptionIsSet(): Unit = { try { ss.read.format(classOf[DataSource].getName) .option("url", SparkConnectorScalaSuiteIT.server.getBoltUrl) .load() .show() // we need the action to be able to trigger the exception because of the changes in Spark 3 org.junit.Assert.fail("Expected to throw an exception") } catch { case e: IllegalArgumentException => assertEquals("No valid option found. 
One of `GDS`, `LABELS`, `QUERY`, `RELATIONSHIP` is required", e.getMessage) case _: Throwable => fail(s"should be thrown a ${classOf[IllegalArgumentException].getName}") } } @Test def testThrowsExceptionIfTwoValidReadOptionAreSet(): Unit = { try { ss.read.format(classOf[DataSource].getName) .option("url", SparkConnectorScalaSuiteIT.server.getBoltUrl) .option("labels", "Person") .option("relationship", "KNOWS") .load() .show() // we need the action to be able to trigger the exception because of the changes in Spark 3 org.junit.Assert.fail("Expected to throw an exception") } catch { case e: IllegalArgumentException => assertEquals( "You need to specify just one of these options: 'gds', 'labels', 'query', 'relationship'", e.getMessage ) case _: Throwable => fail(s"should be thrown a ${classOf[IllegalArgumentException].getName}") } } @Test def testThrowsExceptionIfThreeValidReadOptionAreSet(): Unit = { try { ss.read.format(classOf[DataSource].getName) .option("url", SparkConnectorScalaSuiteIT.server.getBoltUrl) .option("labels", "Person") .option("relationship", "KNOWS") .option("query", "MATCH (n) RETURN n") .load() .show() // we need the action to be able to trigger the exception because of the changes in Spark 3 org.junit.Assert.fail("Expected to throw an exception") } catch { case e: IllegalArgumentException => assertEquals( "You need to specify just one of these options: 'gds', 'labels', 'query', 'relationship'", e.getMessage ) case _: Throwable => fail(s"should be thrown a ${classOf[IllegalArgumentException].getName}") } } @Test def testReadNodeHasIdField(): Unit = { val df: DataFrame = initTest(s"CREATE (p:Person {name: 'John'})") /** * utnaf: Since we can't be sure we are in total isolation, and the id is generated * internally by org.neo4j.neo4j, we just check that the field is an integer and is greater * than -1 */ assertTrue(df.select("").collectAsList().get(0).getLong(0) > -1) } @Test def testReadNodeHasLabelsField(): Unit = { val df: DataFrame = 
initTest(s"CREATE (p:Person:Customer {name: 'John'})") val result = df.select("").collectAsList().get(0).getAs[Seq[String]](0) assertEquals("Person", result.head) assertEquals("Customer", result(1)) } @Test def testReadNodeHasUnusualLabelsField(): Unit = { val df: DataFrame = initTest(s"CREATE (p:`Foo Bar`:Person:`(╯°□°)╯︵ ┻━┻` {name: 'John'})") val result = df.select("").collectAsList().get(0).getAs[Seq[String]](0) assertEquals(Set("Person", "Foo Bar", "(╯°□°)╯︵ ┻━┻"), result.toSet[String]) } @Test def testReadNodeWithFieldWithDifferentTypes(): Unit = { val df: DataFrame = initTest("CREATE (p1:Person {id: 1, field: [12,34]}), (p2:Person {id: 2, field: 123})") val res = df.orderBy("id").collectAsList() assertEquals("[12,34]", res.get(0).get(3)) assertEquals("123", res.get(1).get(3)) } @Test def testReadNodeWithString(): Unit = { val name: String = "John" val df: DataFrame = initTest(s"CREATE (p:Person {name: '$name'})") assertEquals(name, df.select("name").collectAsList().get(0).getString(0)) } @Test def testReadNodeWithLong(): Unit = { val age: Long = 42 val df: DataFrame = initTest(s"CREATE (p:Person {age: $age})") assertEquals(age, df.select("age").collectAsList().get(0).getLong(0)) } @Test def testReadNodeWithDouble(): Unit = { val score: Double = 3.14 val df: DataFrame = initTest(s"CREATE (p:Person {score: $score})") assertEquals(score, df.select("score").collectAsList().get(0).getDouble(0), 0) } @Test def testReadNodeWithLocalTime(): Unit = { val df: DataFrame = initTest(s"CREATE (p:Person {aTime: localtime({hour:12, minute: 23, second: 0, millisecond: 294})})") val result = df.select("aTime").collectAsList().get(0).getAs[GenericRowWithSchema](0) assertEquals("local-time", result.get(0)) assertEquals("12:23:00.294", result.get(1)) } @Test def testReadNodeWithTime(): Unit = { val timezone = TimeZone.getDefault val df: DataFrame = initTest(s"CREATE (p:Person {aTime: time({hour:12, minute: 23, second: 0, millisecond: 294})})") val result = 
df.select("aTime").collectAsList().get(0).getAs[GenericRowWithSchema](0)

    val localTime = LocalTime.of(12, 23, 0, 294000000)
    val expectedTime = OffsetTime.of(localTime, timezone.toZoneId.getRules.getOffset(Instant.now))

    assertEquals("offset-time", result.get(0))
    assertEquals(expectedTime.toString, result.get(1))
  }

  @Test
  def testReadNodeWithLocalDateTime(): Unit = {
    val localDateTime = "2007-12-03T10:15:30"
    val df: DataFrame = initTest(s"CREATE (p:Person {aTime: localdatetime('$localDateTime')})")

    val result = df.select("aTime").collectAsList().get(0).getAs[LocalDateTime](0)

    assertEquals(LocalDateTime.parse(localDateTime), result)
  }

  // Zoned datetimes map to java.sql.Timestamp (instant-preserving).
  @Test
  def testReadNodeWithZonedDateTime(): Unit = {
    val datetime = "2015-06-24T12:50:35.556+01:00"
    val df: DataFrame = initTest(s"CREATE (p:Person {aTime: datetime('$datetime')})")

    val result = df.select("aTime").collectAsList().get(0).getTimestamp(0)

    assertEquals(Timestamp.from(OffsetDateTime.parse(datetime).toInstant), result)
  }

  // Cartesian points map to a struct (type, srid=7203, x, y).
  @Test
  def testReadNodeWithPoint(): Unit = {
    val df: DataFrame = initTest(s"CREATE (p:Person {location: point({x: 12.12, y: 13.13})})")

    val res = df.select("location").collectAsList().get(0).getAs[GenericRowWithSchema](0)

    assertEquals("point-2d", res.get(0))
    assertEquals(7203, res.get(1))
    assertEquals(12.12, res.get(2))
    assertEquals(13.13, res.get(3))
  }

  // Geographic (WGS-84) points carry srid 4326.
  @Test
  def testReadNodeWithGeoPoint(): Unit = {
    val df: DataFrame = initTest(s"CREATE (p:Person {location: point({longitude: 12.12, latitude: 13.13})})")

    val res = df.select("location").collectAsList().get(0).getAs[GenericRowWithSchema](0)

    assertEquals("point-2d", res.get(0))
    assertEquals(4326, res.get(1))
    assertEquals(12.12, res.get(2))
    assertEquals(13.13, res.get(3))
  }

  // 3-D cartesian points carry srid 9157 and a z coordinate.
  @Test
  def testReadNodeWithPoint3D(): Unit = {
    val df: DataFrame = initTest(s"CREATE (p:Person {location: point({x: 12.12, y: 13.13, z: 1})})")

    val res = df.select("location").collectAsList().get(0).getAs[GenericRowWithSchema](0)

    assertEquals("point-3d", res.get(0))
    assertEquals(9157, res.get(1))
    assertEquals(12.12, res.get(2))
    assertEquals(13.13, res.get(3))
    assertEquals(1.0, res.get(4))
  }

  @Test
  def testReadNodeWithDate(): Unit = {
    val df: DataFrame = initTest(s"CREATE (p:Person {born: date('2009-10-10')})")

    val list = df.select("born").collectAsList()
    val res = list.get(0).getDate(0)

    assertEquals(java.sql.Date.valueOf("2009-10-10"), res)
  }

  // Durations map to a struct: (type, months, days, seconds, nanos, ISO string).
  @Test
  def testReadNodeWithDuration(): Unit = {
    val df: DataFrame = initTest(s"CREATE (p:Person {range: duration({days: 14, hours:16, minutes: 12})})")

    val list = df.select("range").collectAsList()
    val res = list.get(0).getAs[GenericRowWithSchema](0)

    assertEquals("duration", res(0))
    assertEquals(0L, res(1))
    assertEquals(14L, res(2))
    assertEquals(58320L, res(3))
    assertEquals(0, res(4))
    assertEquals("P0M14DT58320S", res(5))
  }

  @Test
  def testReadNodeWithStringArray(): Unit = {
    val df: DataFrame = initTest(s"CREATE (p:Person {names: ['John', 'Doe']})")

    val res = df.select("names").collectAsList().get(0).getAs[Seq[String]](0)

    assertEquals("John", res.head)
    assertEquals("Doe", res(1))
  }

  @Test
  def testReadNodeWithLongArray(): Unit = {
    val df: DataFrame = initTest(s"CREATE (p:Person {ages: [22, 23]})")

    val res = df.select("ages").collectAsList().get(0).getAs[Seq[Long]](0)

    assertEquals(22, res.head)
    assertEquals(23, res(1))
  }

  @Test
  def testReadNodeWithDoubleArray(): Unit = {
    val df: DataFrame = initTest(s"CREATE (p:Person {scores: [22.33, 44.55]})")

    val res = df.select("scores").collectAsList().get(0).getAs[Seq[Double]](0)

    assertEquals(22.33, res.head, 0)
    assertEquals(44.55, res(1), 0)
  }

  // Datetime arrays come back as Timestamps; compare at UTC to be TZ-independent.
  @Test
  def testReadNodeWithTimestampArray(): Unit = {
    val df: DataFrame = initTest(
      s"CREATE (p:Person {someTimes: [datetime('2010-10-10T11:13:37+01:00'), datetime('2011-11-11T10:13:37Z')]})"
    )

    val res = df.select("someTimes").collectAsList().get(0).getAs[Seq[Timestamp]](0)

    assertEquals("2010-10-10T10:13:37Z", res.head.toInstant.atZone(ZoneOffset.UTC).toString)
    assertEquals("2011-11-11T10:13:37Z", res(1).toInstant.atZone(ZoneOffset.UTC).toString)
  }
@Test def testReadNodeWithLocalTimeArray(): Unit = { val df: DataFrame = initTest(s"CREATE (p:Person {someTimes: [localtime({hour:12}), localtime({hour:1, minute: 3})]})") val res = df.select("someTimes").collectAsList().get(0).getAs[Seq[GenericRowWithSchema]](0) assertEquals("local-time", res.head.get(0)) assertEquals("12:00:00", res.head.get(1)) assertEquals("local-time", res(1).get(0)) assertEquals("01:03:00", res(1).get(1)) } @Test def testReadNodeWithBooleanArray(): Unit = { val df: DataFrame = initTest(s"CREATE (p:Person {bools: [true, false]})") val res = df.select("bools").collectAsList().get(0).getAs[Seq[Boolean]](0) assertEquals(true, res.head) assertEquals(false, res(1)) } @Test def testReadNodeWithPointArray(): Unit = { val df: DataFrame = initTest(s"CREATE (p:Person {locations: [point({x: 11, y: 33.111}), point({x: 22, y: 44.222})]})") val res = df.select("locations").collectAsList().get(0).getAs[Seq[GenericRowWithSchema]](0) assertEquals("point-2d", res.head.get(0)) assertEquals(7203, res.head.get(1)) assertEquals(11.0, res.head.get(2)) assertEquals(33.111, res.head.get(3)) assertEquals("point-2d", res(1).get(0)) assertEquals(7203, res(1).get(1)) assertEquals(22.0, res(1).get(2)) assertEquals(44.222, res(1).get(3)) } @Test def testReadNodeWithGeoPointArray(): Unit = { val df: DataFrame = initTest( s"CREATE (p:Person {locations: [point({longitude: 11, latitude: 33.111}), point({longitude: 22, latitude: 44.222})]})" ) val res = df.select("locations").collectAsList().get(0).getAs[Seq[GenericRowWithSchema]](0) assertEquals("point-2d", res.head.get(0)) assertEquals(4326, res.head.get(1)) assertEquals(11.0, res.head.get(2)) assertEquals(33.111, res.head.get(3)) assertEquals("point-2d", res(1).get(0)) assertEquals(4326, res(1).get(1)) assertEquals(22.0, res(1).get(2)) assertEquals(44.222, res(1).get(3)) } @Test def testReadNodeWithPoint3DArray(): Unit = { val df: DataFrame = initTest(s"CREATE (p:Person {locations: [point({x: 11, y: 33.111, z: 12}), point({x: 
22, y: 44.222, z: 99.1})]})") val res = df.select("locations").collectAsList().get(0).getAs[Seq[GenericRowWithSchema]](0) assertEquals("point-3d", res.head.get(0)) assertEquals(9157, res.head.get(1)) assertEquals(11.0, res.head.get(2)) assertEquals(33.111, res.head.get(3)) assertEquals(12.0, res.head.get(4)) assertEquals("point-3d", res(1).get(0)) assertEquals(9157, res(1).get(1)) assertEquals(22.0, res(1).get(2)) assertEquals(44.222, res(1).get(3)) assertEquals(99.1, res(1).get(4)) } @Test def testReadNodeWithArrayDate(): Unit = { val df: DataFrame = initTest(s"CREATE (p:Person {dates: [date('2009-10-10'), date('2009-10-11')]})") val res = df.select("dates").collectAsList().get(0).getAs[Seq[java.sql.Date]](0) assertEquals(java.sql.Date.valueOf("2009-10-10"), res.head) assertEquals(java.sql.Date.valueOf("2009-10-11"), res(1)) } @Test def testReadNodeWithArrayZonedDateTime(): Unit = { val datetime1 = "2015-06-24T12:50:35.556+01:00" val datetime2 = "2015-06-23T12:50:35.556+01:00" val df: DataFrame = initTest( s""" CREATE (p:Person {aTime: [ datetime('$datetime1'), datetime('$datetime2') ]}) """ ) val result = df.select("aTime").collectAsList().get(0).getAs[Seq[Timestamp]](0) assertEquals(Timestamp.from(OffsetDateTime.parse(datetime1).toInstant), result.head) assertEquals(Timestamp.from(OffsetDateTime.parse(datetime2).toInstant), result(1)) } @Test def testReadNodeWithArrayDurations(): Unit = { val df: DataFrame = initTest(s"CREATE (p:Person {durations: [duration({months: 0.75}), duration({weeks: 2.5})]})") val res = df.select("durations").collectAsList().get(0).getAs[Seq[GenericRowWithSchema]](0) assertEquals("duration", res.head.get(0)) assertEquals(0L, res.head.get(1)) assertEquals(22L, res.head.get(2)) assertEquals(71509L, res.head.get(3)) assertEquals(500000000, res.head.get(4)) assertEquals("P0M22DT71509.500000000S", res.head.get(5)) assertEquals("duration", res(1).get(0)) assertEquals(0L, res(1).get(1)) assertEquals(17L, res(1).get(2)) assertEquals(43200L, 
res(1).get(3)) assertEquals(0, res(1).get(4)) assertEquals("P0M17DT43200S", res(1).get(5)) } @Test def testReadNodeWithBinary(): Unit = { val bytes = "hello, world!".map(_.toByte).toArray val parameters = new java.util.HashMap[String, Object]() parameters.put("bytes", bytes) SparkConnectorScalaSuiteIT.session() .writeTransaction((tx: Transaction) => tx.run("CREATE (h:Hello {b: $bytes})", parameters).consume()) val df = ss.read.format(classOf[DataSource].getName) .option("url", SparkConnectorScalaSuiteIT.server.getBoltUrl) .option("labels", "Hello") .load() val res = df.select("b").collect() assertEquals(1, res.length) val gotBytes = res.head.getAs[Array[Byte]](0) for (i <- bytes.indices) { assertEquals(bytes(i), gotBytes(i)) } } @Test def testReadNodeWithEqualToFilter(): Unit = { val df: DataFrame = initTest( s""" CREATE (p1:Person {name: 'John Doe'}), (p2:Person {name: 'Jane Doe'}) """ ) val result = df.select("name").where("name = 'John Doe'").collectAsList() assertEquals(1, result.size()) assertEquals("John Doe", result.get(0).getString(0)) } @Test def testReadNodeWithEqualToDateFilter(): Unit = { val df: DataFrame = initTest( s""" CREATE (p1:Person {birth: date('1998-02-04')}), (p2:Person {birth: date('1988-01-05')}) """ ) val result = df.select("birth").where("birth = '1988-01-05'").collectAsList() assertEquals(1, result.size()) assertEquals(java.sql.Date.valueOf("1988-01-05"), result.get(0).getDate(0)) } @Test def testReadNodeWithTimestampGteFilter(): Unit = { val localDateTime = "2007-12-03T10:15:30" val df: DataFrame = initTest( s""" CREATE (p1:Person {birth: localdatetime('$localDateTime')}), (p2:Person {birth: localdatetime('$localDateTime')}) """ ) df.printSchema() df.show() val result = df.select("birth").where(s"birth >= '$localDateTime'").collectAsList() assertEquals(2, result.size()) } @Test def testReadNodeWithNotEqualToFilter(): Unit = { val df: DataFrame = initTest( s""" CREATE (p1:Person {name: 'John Doe'}), (p2:Person {name: 'Jane Doe'}) """ ) 
val result = df.select("name").where("NOT name = 'John Doe'").collectAsList() assertEquals(1, result.size()) assertEquals("Jane Doe", result.get(0).getString(0)) } @Test def testReadNodeWithNotEqualToDateFilter(): Unit = { val df: DataFrame = initTest( s""" CREATE (p1:Person {birth: date('1998-02-04')}), (p2:Person {birth: date('1988-01-05')}) """ ) val result = df.select("birth").where("NOT birth = '1988-01-05'").collectAsList() assertEquals(1, result.size()) assertEquals(java.sql.Date.valueOf("1998-02-04"), result.get(0).getDate(0)) } @Test def testReadNodeWithDifferentOperatorFilter(): Unit = { val df: DataFrame = initTest( s""" CREATE (p1:Person {name: 'John Doe'}), (p2:Person {name: 'Jane Doe'}) """ ) val result = df.select("name").where("name != 'John Doe'").collectAsList() assertEquals(1, result.size()) assertEquals("Jane Doe", result.get(0).getString(0)) } @Test def testReadNodeWithGtFilter(): Unit = { val df: DataFrame = initTest( s""" CREATE (p1:Person {age: 19}), (p2:Person {age: 20}), (p3:Person {age: 21}) """ ) val result = df.select("age").where("age > 20").collectAsList() assertEquals(1, result.size()) assertEquals(21, result.get(0).getLong(0)) } @Test def testReadNodeWithGtDateFilter(): Unit = { val df: DataFrame = initTest( s""" CREATE (p1:Person {birth: date('1998-02-04')}), (p2:Person {birth: date('1988-01-05')}), (p3:Person {birth: date('1994-10-16')}) """ ) val result = df.select("birth").orderBy("birth").where("birth > '1990-01-01'").collectAsList() assertEquals(2, result.size()) assertEquals(java.sql.Date.valueOf("1994-10-16"), result.get(0).getDate(0)) assertEquals(java.sql.Date.valueOf("1998-02-04"), result.get(1).getDate(0)) } @Test def testReadNodeWithGtSpatialFilter(): Unit = { val df: DataFrame = initTest( s""" CREATE (p:Person {location: point({x: 12, y: 12})}), (p2:Person {location: point({x: -6, y: -6})}) """ ) val result = df.select("location").where("location.x > 0").collectAsList() val row = 
result.get(0).getAs[GenericRowWithSchema](0); assertEquals(1, result.size()) assertEquals("point-2d", row.get(0)) assertEquals(7203, row.get(1)) assertEquals(12.0, row.get(2)) assertEquals(12.0, row.get(3)) } @Test def testReadNodeWithGteFilter(): Unit = { val df: DataFrame = initTest( s""" CREATE (p1:Person {age: 19}), (p2:Person {age: 20}), (p3:Person {age: 21}) """ ) val result = df.select("age").orderBy("age").where("age >= 20").collectAsList() assertEquals(2, result.size()) assertEquals(20, result.get(0).getLong(0)) assertEquals(21, result.get(1).getLong(0)) } @Test def testReadNodeWithGteFilterWithProp(): Unit = { val df: DataFrame = initTest( s""" CREATE (p1:Person {score: 19, limit: 20}), (p2:Person {score: 20, limit: 18}), (p3:Person {score: 21, limit: 12}) """ ) val result = df.select("score").orderBy("score").where("score >= limit").collectAsList() assertEquals(2, result.size()) assertEquals(20, result.get(0).getLong(0)) assertEquals(21, result.get(1).getLong(0)) } @Test def testReadNodeWithLtFilter(): Unit = { val df: DataFrame = initTest( s""" CREATE (p1:Person {age: 39}), (p2:Person {age: 41}), (p3:Person {age: 43}) """ ) val result = df.select("age").orderBy("age").where("age < 40").collectAsList() assertEquals(1, result.size()) assertEquals(39, result.get(0).getLong(0)) } @Test def testReadNodeWithLteFilter(): Unit = { val df: DataFrame = initTest( s""" CREATE (p1:Person {age: 39}), (p2:Person {age: 41}), (p3:Person {age: 43}) """ ) val result = df.select("age").orderBy("age").where("age <= 41").collectAsList() assertEquals(2, result.size()) assertEquals(39, result.get(0).getLong(0)) assertEquals(41, result.get(1).getLong(0)) } @Test def testReadNodeWithInFilter(): Unit = { val df: DataFrame = initTest( s""" CREATE (p1:Person {age: 39}), (p2:Person {age: 41}), (p3:Person {age: 43}) """ ) val result = df.select("age").orderBy("age").where("age IN(41,43)").collectAsList() assertEquals(2, result.size()) assertEquals(41, result.get(0).getLong(0)) 
assertEquals(43, result.get(1).getLong(0)) } @Test def testReadNodeWithIsNullFilter(): Unit = { val df: DataFrame = initTest( s""" CREATE (p1:Person {age: 39}), (p2:Person {age: null}), (p3:Person {age: 43}) """ ) val result = df.select("age").where("age IS NULL").collectAsList() assertEquals(1, result.size()) assertNull(result.get(0).get(0)) } @Test def testReadNodeWithIsNotNullFilter(): Unit = { val df: DataFrame = initTest( s""" CREATE (p1:Person {age: 39}), (p2:Person {age: null}), (p3:Person {age: 43}) """ ) val result = df.select("age").orderBy("age").where("age IS NOT NULL").collectAsList() assertEquals(2, result.size()) assertEquals(39, result.get(0).getLong(0)) assertEquals(43, result.get(1).getLong(0)) } @Test def testReadNodeWithOrCondition(): Unit = { val df: DataFrame = initTest( s""" CREATE (p1:Person {age: 39}), (p2:Person {age: null}), (p3:Person {age: 43}) """ ) val result = df.select("age").orderBy("age").where("age = 43 OR age = 39 OR age = 32").collectAsList() assertEquals(2, result.size()) assertEquals(39, result.get(0).getLong(0)) assertEquals(43, result.get(1).getLong(0)) } @Test def testReadNodeWithAndCondition(): Unit = { val df: DataFrame = initTest( s""" CREATE (p1:Person {age: 39}), (p2:Person {age: null}), (p3:Person {age: 43}) """ ) val result = df.select("age").orderBy("age").where("age >= 39 AND age <= 43").collectAsList() assertEquals(2, result.size()) assertEquals(39, result.get(0).getLong(0)) assertEquals(43, result.get(1).getLong(0)) } @Test def testReadNodeWithStartsWith(): Unit = { val df: DataFrame = initTest( s""" CREATE (p1:Person {name: 'John Mayer'}), (p2:Person {name: 'John Scofield'}), (p3:Person {name: 'John Butler'}) """ ) val result = df.select("name").orderBy("name").where("name LIKE 'John%'").collectAsList() assertEquals(3, result.size()) assertEquals("John Butler", result.get(0).getString(0)) assertEquals("John Mayer", result.get(1).getString(0)) assertEquals("John Scofield", result.get(2).getString(0)) } @Test def 
testReadNodeWithEndsWith(): Unit = {
    val df: DataFrame = initTest(
      s"""
     CREATE (p1:Person {name: 'John Mayer'}),
      (p2:Person {name: 'John Scofield'}),
      (p3:Person {name: 'John Butler'})
    """
    )

    val result = df.select("name").where("name LIKE '%Scofield'").collectAsList()

    assertEquals(1, result.size())
    assertEquals("John Scofield", result.get(0).getString(0))
  }

  @Test
  def testReadNodeWithContains(): Unit = {
    val df: DataFrame = initTest(
      s"""
     CREATE (p1:Person {name: 'John Mayer'}),
      (p2:Person {name: 'John Scofield'}),
      (p3:Person {name: 'John Butler'})
    """
    )

    val result = df.select("name").where("name LIKE '%ay%'").collectAsList()

    assertEquals(1, result.size())
    assertEquals("John Mayer", result.get(0).getString(0))
  }

  // With relationship.nodes.map=true, source/target are exposed as the
  // `<source>` / `<target>` map columns; filter on their nested `id` entries.
  @Test
  def testRelFiltersWithMap(): Unit = {
    val fixtureQuery: String =
      """UNWIND range(1,100) as id
        |CREATE (p:Person {id:id,ids:[id,id]}) WITH collect(p) as people
        |UNWIND people as p1
        |UNWIND range(1,10) as friend
        |WITH p1, people[(p1.id + friend) % size(people)] as p2
        |CREATE (p1)-[:KNOWS]->(p2)
        |RETURN *
      """.stripMargin

    SparkConnectorScalaSuiteIT.session()
      .writeTransaction(
        new TransactionWork[ResultSummary] {
          override def execute(tx: Transaction): ResultSummary = tx.run(fixtureQuery).consume()
        }
      )

    val df = ss.read.format(classOf[DataSource].getName)
      .option("url", SparkConnectorScalaSuiteIT.server.getBoltUrl)
      .option("relationship.nodes.map", "true")
      .option("relationship", "KNOWS")
      .option("relationship.source.labels", "Person")
      .option("relationship.target.labels", "Person")
      .load()

    // NOTE(review): the `<source>`/`<target>` special column names were stripped
    // to empty backticks by the HTML-flattening of this dump; restored here.
    assertEquals(1, df.filter("`<source>`.`id` = '14' AND `<target>`.`id` = '16'").count)
  }

  // With relationship.nodes.map=false, node properties are flattened into
  // `source.*` / `target.*` columns.
  @Test
  def testRelFiltersWithoutMap(): Unit = {
    val fixtureQuery: String =
      """UNWIND range(1,100) as id
        |CREATE (p:Person {id:id,ids:[id,id]}) WITH collect(p) as people
        |UNWIND people as p1
        |UNWIND range(1,10) as friend
        |WITH p1, people[(p1.id + friend) % size(people)] as p2
        |CREATE (p1)-[:KNOWS]->(p2)
        |RETURN *
      """.stripMargin

    SparkConnectorScalaSuiteIT.session()
      .writeTransaction(
        new TransactionWork[ResultSummary] {
          override def execute(tx: Transaction): ResultSummary = tx.run(fixtureQuery).consume()
        }
      )

    val df = ss.read.format(classOf[DataSource].getName)
      .option("url", SparkConnectorScalaSuiteIT.server.getBoltUrl)
      .option("relationship", "KNOWS")
      .option("relationship.nodes.map", "false")
      .option("relationship.source.labels", "Person")
      .option("relationship.target.labels", "Person")
      .load()

    assertEquals(1, df.filter("`source.id` = 14 AND `target.id` = 16").count)
  }

  @Test
  def testReadRelationshipFilters(): Unit = {
    val fixtureQuery: String =
      """UNWIND range(1,100) as id
        |CREATE (p:Person {id:id,ids:[id,id]}) WITH collect(p) as people
        |UNWIND people as p1
        |UNWIND range(1,10) as friend
        |WITH p1, people[(p1.id + friend) % size(people)] as p2
        |CREATE (p1)-[:KNOWS]->(p2)
        |RETURN *
      """.stripMargin
    val df: DataFrame = initTest(fixtureQuery)

    val repartitionedDf = df.repartition(10)

    assertEquals(10, repartitionedDf.rdd.getNumPartitions)
    val numNode = repartitionedDf.collect().length
    assertEquals(100, numNode)
  }

  // Mixed string/long ids across nodes are coerced consistently per column.
  @Test
  def testRelationshipsDifferentFieldValues(): Unit = {
    val fixtureQuery: String =
      s"""CREATE (pr1:Product {id: '1'})
         |CREATE (pr2:Product {id: 2})
         |CREATE (pe1:Person {id: '3'})
         |CREATE (pe2:Person {id: 4})
         |CREATE (pe1)-[:BOUGHT]->(pr1)
         |CREATE (pe2)-[:BOUGHT]->(pr2)
         |RETURN *
      """.stripMargin
    SparkConnectorScalaSuiteIT.session()
      .writeTransaction(
        new TransactionWork[ResultSummary] {
          override def execute(tx: Transaction): ResultSummary = tx.run(fixtureQuery).consume()
        }
      )

    val df: DataFrame = ss.read.format(classOf[DataSource].getName)
      .option("url", SparkConnectorScalaSuiteIT.server.getBoltUrl)
      .option("relationship.nodes.map", "false")
      .option("relationship", "BOUGHT")
      .option("relationship.source.labels", ":Person")
      .option("relationship.target.labels", ":Product")
      .load()

    val res = df.sort("`source.id`").collectAsList()

    assertEquals("3", res.get(0).get(4))
    assertEquals("1", res.get(0).get(7))
    assertEquals("4", res.get(1).get(4))
assertEquals("2", res.get(1).get(7)) } @Test def testReadNodesCustomPartitions(): Unit = { val fixtureQuery: String = """UNWIND range(1,100) as id |CREATE (p:Person:Customer {id: id, name: 'Person ' + id}) |RETURN * """.stripMargin val fixture2Query: String = """UNWIND range(1,100) as id |CREATE (p:Employee:Customer {id: id, name: 'Person ' + id}) |RETURN * """.stripMargin SparkConnectorScalaSuiteIT.driver.session() .writeTransaction( new TransactionWork[ResultSummary] { override def execute(tx: Transaction): ResultSummary = tx.run(fixtureQuery).consume() } ) SparkConnectorScalaSuiteIT.driver.session() .writeTransaction( new TransactionWork[ResultSummary] { override def execute(tx: Transaction): ResultSummary = tx.run(fixture2Query).consume() } ) val partitionedDf = ss.read.format(classOf[DataSource].getName) .option("url", SparkConnectorScalaSuiteIT.server.getBoltUrl) .option("labels", ":Person:Customer") .option("partitions", "5") .load() assertEquals(5, partitionedDf.rdd.getNumPartitions) assertEquals(100, partitionedDf.collect().map(_.getAs[Long]("id")).toSet.size) assertEquals(100, partitionedDf.collect().map(_.getAs[Long]("id")).size) } @Test def testReadRelsCustomPartitions(): Unit = { val fixtureQuery: String = """UNWIND range(1,100) as id |CREATE (p:Person {id: id, name: 'Person ' + id})-[:BOUGHT{quantity: ceil(rand() * 100)}]->(:Product{id: id, name: 'Product ' + id}) |RETURN * """.stripMargin SparkConnectorScalaSuiteIT.driver.session() .writeTransaction( new TransactionWork[ResultSummary] { override def execute(tx: Transaction): ResultSummary = tx.run(fixtureQuery).consume() } ) val partitionedDf = ss.read.format(classOf[DataSource].getName) .option("url", SparkConnectorScalaSuiteIT.server.getBoltUrl) .option("relationship.nodes.map", "true") .option("relationship", "BOUGHT") .option("relationship.source.labels", ":Person") .option("relationship.target.labels", ":Product") .option("partitions", "5") .load() assertEquals(5, 
partitionedDf.rdd.getNumPartitions) assertEquals(100, partitionedDf.collect().map(_.getAs[Long]("")).toSet.size) assertEquals(100, partitionedDf.collect().map(_.getAs[Long]("")).size) } @Test def testReadQueryCustomPartitions(): Unit = { val fixtureProduct1Query: String = """CREATE (pr:Product{id: 1, name: 'Product 1'}) |WITH pr |UNWIND range(1,100) as id |CREATE (p:Person {id: id, name: 'Person ' + id})-[:BOUGHT{quantity: ceil(rand() * 100)}]->(pr) |RETURN * """.stripMargin SparkConnectorScalaSuiteIT.driver.session() .writeTransaction( new TransactionWork[ResultSummary] { override def execute(tx: Transaction): ResultSummary = tx.run(fixtureProduct1Query).consume() } ) val fixtureProduct2Query: String = """CREATE (pr:Product{id: 2, name: 'Product 2'}) |WITH pr |UNWIND range(1,50) as id |MATCH (p:Person {id: id}) |CREATE (p)-[:BOUGHT{quantity: ceil(rand() * 100)}]->(pr) |RETURN * """.stripMargin SparkConnectorScalaSuiteIT.driver.session() .writeTransaction( new TransactionWork[ResultSummary] { override def execute(tx: Transaction): ResultSummary = tx.run(fixtureProduct2Query).consume() } ) val partitionedQueryCountDf = ss.read.format(classOf[DataSource].getName) .option("url", SparkConnectorScalaSuiteIT.server.getBoltUrl) .option( "query", """ |MATCH (p:Person)-[r:BOUGHT]->(pr:Product{name: 'Product 2'}) |RETURN p.name AS person, pr.name AS product, r.quantity AS quantity""".stripMargin ) .option("partitions", "5") .option( "query.count", """ |MATCH (p:Person)-[r:BOUGHT]->(pr:Product{name: 'Product 2'}) |RETURN count(p) AS count""".stripMargin ) .load() assertEquals(5, partitionedQueryCountDf.rdd.getNumPartitions) assertEquals(50, partitionedQueryCountDf.collect().map(_.getAs[String]("person")).toSet.size) assertEquals(50, partitionedQueryCountDf.collect().map(_.getAs[String]("person")).length) val partitionedQueryCountLiteralDf = ss.read.format(classOf[DataSource].getName) .option("url", SparkConnectorScalaSuiteIT.server.getBoltUrl) .option( "query", """ |MATCH 
(p:Person)-[r:BOUGHT]->(pr:Product{name: 'Product 2'}) |RETURN p.name AS person, pr.name AS product, r.quantity AS quantity""".stripMargin ) .option("partitions", "5") .option("query.count", "50") .load() assertEquals(5, partitionedQueryCountLiteralDf.rdd.getNumPartitions) assertEquals(50, partitionedQueryCountLiteralDf.collect().map(_.getAs[String]("person")).toSet.size) assertEquals(50, partitionedQueryCountLiteralDf.collect().map(_.getAs[String]("person")).length) } @Test def testRelationshipsFlatten(): Unit = { val total = 100 val fixtureQuery: String = s"""UNWIND range(1, $total) as id |CREATE (pr:Product {id: id * rand(), name: 'Product ' + id}) |CREATE (pe:Person {id: id, fullName: 'Person ' + id}) |CREATE (pe)-[:BOUGHT{when: rand(), quantity: rand() * 1000}]->(pr) |RETURN * """.stripMargin SparkConnectorScalaSuiteIT.session() .writeTransaction( new TransactionWork[ResultSummary] { override def execute(tx: Transaction): ResultSummary = tx.run(fixtureQuery).consume() } ) val df: DataFrame = ss.read.format(classOf[DataSource].getName) .option("url", SparkConnectorScalaSuiteIT.server.getBoltUrl) .option("relationship", "BOUGHT") .option("relationship.nodes.map", "false") .option("relationship.source.labels", ":Person") .option("relationship.target.labels", ":Product") .load() val count = df.collectAsList() .asScala .filter(row => row.getAs[Long]("") >= 0 && row.getAs[String]("") != null && row.getAs[Double]("rel.when") >= 0 && row.getAs[Double]("rel.quantity") >= 0 && row.getAs[Long]("") >= 0 && row.getAs[Long]("source.id") >= 0 && !row.getAs[Seq[String]]("").isEmpty && row.getAs[String]("source.fullName") != null && row.getAs[Long]("") >= 0 && row.getAs[Double]("target.id") >= 0 && !row.getAs[Seq[String]]("").isEmpty && row.getAs[String]("target.name") != null ) .size assertEquals(total, count) } @Test def testRelationshipsMap(): Unit = { val total = 100 val fixtureQuery: String = s"""UNWIND range(1, $total) as id |CREATE (pr:Product {id: id * rand(), name: 
'Product ' + id})
        |CREATE (pe:Person {id: id, fullName: 'Person ' + id})
        |CREATE (pe)-[:BOUGHT{when: rand(), quantity: rand() * 1000}]->(pr)
        |RETURN *
        """.stripMargin
    SparkConnectorScalaSuiteIT.session()
      .writeTransaction(
        new TransactionWork[ResultSummary] {
          override def execute(tx: Transaction): ResultSummary = tx.run(fixtureQuery).consume()
        }
      )

    // With relationship.nodes.map=true the source/target nodes are exposed as
    // "<source>"/"<target>" map columns (restored: the extraction stripped these tokens).
    val df: DataFrame = ss.read.format(classOf[DataSource].getName)
      .option("url", SparkConnectorScalaSuiteIT.server.getBoltUrl)
      .option("relationship.nodes.map", "true")
      .option("relationship", "BOUGHT")
      .option("relationship.source.labels", ":Person")
      .option("relationship.target.labels", ":Product")
      .load()

    val rows = df.collectAsList().asScala
    val count = rows
      .filter(row =>
        row.getAs[Long]("<rel.id>") >= 0
          && row.getAs[String]("<rel.type>") != null
          && row.getAs[Double]("rel.when") >= 0
          && row.getAs[Double]("rel.quantity") >= 0
          && row.getAs[Map[String, String]]("<source>") != null
          && row.getAs[Map[String, String]]("<target>") != null
      )
      .size
    assertEquals(total, count)

    // each node map carries its own properties plus the internal "<id>"/"<labels>" entries
    val countSourceMap = rows.map(row => row.getAs[Map[String, String]]("<source>"))
      .filter(row => row.keys == Set("id", "fullName", "<id>", "<labels>"))
      .size
    assertEquals(total, countSourceMap)
    val countTargetMap = rows.map(row => row.getAs[Map[String, String]]("<target>"))
      .filter(row => row.keys == Set("id", "name", "<id>", "<labels>"))
      .size
    assertEquals(total, countTargetMap)
  }

  // Maps, lists of maps and heterogeneous lists coming back from Cypher are coerced to strings.
  @Test
  def testQueries(): Unit = {
    val dfMap = ss.read.format(classOf[DataSource].getName)
      .option("url", SparkConnectorScalaSuiteIT.server.getBoltUrl)
      .option("query", "RETURN {a: 1, b: '3'} AS map")
      .load()
    val map = dfMap.collect()(0).getAs[Map[String, String]]("map")
    val expectedMap = Map("a" -> "1", "b" -> "3")
    assertEquals(expectedMap, map)

    val dfArrayMap = ss.read.format(classOf[DataSource].getName)
      .option("url", SparkConnectorScalaSuiteIT.server.getBoltUrl)
      .option("query", "RETURN [{a: 1, b: '3'}, {a: 'foo'}] AS listMap")
      .load()
    val listMap = dfArrayMap.collect()(0).getAs[Seq[_]]("listMap").toList
    val expectedListMap = Seq(Map("a" -> "1", "b" -> "3"), Map("a" -> "foo"))
    assertEquals(expectedListMap, listMap)

    val dfArray = ss.read.format(classOf[DataSource].getName)
      .option("url", SparkConnectorScalaSuiteIT.server.getBoltUrl)
      .option("query", "RETURN [1, 'foo'] AS list")
      .load()
    val list = dfArray.collect()(0).getAs[Seq[_]]("list")
    val expectedList = Seq("1", "foo")
    assertEquals(expectedList, list)
  }

  @Test
  def testComplexQuery(): Unit = {
    val total = 100
    val fixtureQuery: String =
      s"""UNWIND range(1, $total) as id
        |CREATE (pr:Product {id: id * rand(), name: 'Product ' + id})
        |CREATE (pe:Person {id: id, fullName: 'Person ' + id})
        |CREATE (pe)-[:BOUGHT{when: rand(), quantity: rand() * 1000}]->(pr)
        |RETURN *
        """.stripMargin
    SparkConnectorScalaSuiteIT.session()
      .writeTransaction(
        new TransactionWork[ResultSummary] {
          override def execute(tx: Transaction): ResultSummary = tx.run(fixtureQuery).consume()
        }
      )

    // collect(n) returns nodes as structs carrying the internal "<id>"/"<labels>" fields
    val df: DataFrame = ss.read.format(classOf[DataSource].getName)
      .option("url", SparkConnectorScalaSuiteIT.server.getBoltUrl)
      .option("query", "MATCH (n:Person) WITH n LIMIT 2 RETURN collect(n) AS nodes")
      .load()

    val data = df.collect()
    val count = data.flatMap(row => row.getAs[Seq[Row]]("nodes"))
      .filter(row =>
        row.getAs[Long]("<id>") >= 0
          && !row.getAs[Seq[String]]("<labels>").isEmpty
          && !row.getAs[String]("fullName").isEmpty
          && row.getAs[Long]("id") >= 0
      )
      .size
    assertEquals(2, count)

    val dfString: DataFrame = ss.read.format(classOf[DataSource].getName)
      .option("url", SparkConnectorScalaSuiteIT.server.getBoltUrl)
      .option(
        "query",
        """MATCH (p:Person)-[b:BOUGHT]->(pr:Product)
          |RETURN id(p) AS personId, id(pr) AS productId, {quantity: b.quantity, when: b.when} AS map""".stripMargin
      )
      .option("schema.strategy", "string")
      .load()
    val dataString = dfString.collect()
    val countString = dataString
      .filter(row =>
        !row.getAs[String]("personId").isEmpty
          && !row.getAs[String]("productId").isEmpty
          && !row.getAs[String]("map").isEmpty
      )
      .size
    assertEquals(100, countString)

    val dfRel: DataFrame = ss.read.format(classOf[DataSource].getName)
      .option("url", SparkConnectorScalaSuiteIT.server.getBoltUrl)
      .option(
        "query",
        """MATCH (p:Person)-[b:BOUGHT]->(pr:Product)
          |RETURN b AS rel""".stripMargin
      )
      .load()
    val dataRel = dfRel.collect()
    val countRel = dataRel
      .map(_.getAs[Row]("rel"))
      .filter(row =>
        row.getAs[Long]("<rel.id>") >= 0
          && !row.getAs[String]("<rel.type>").isEmpty
          && row.getAs[Long]("<source.id>") >= 0
          && row.getAs[Long]("<target.id>") >= 0
          && row.getAs[Double]("when") != null
          && row.getAs[Double]("quantity") != null
      )
      .size
    assertEquals(100, countRel)
  }

  @Test
  def testShouldCreateTheCorrectDataframeWithTwoPartitions(): Unit = {
    SparkConnectorScalaSuiteIT.session()
      .writeTransaction(
        new TransactionWork[ResultSummary] {
          override def execute(tx: Transaction): ResultSummary =
            tx.run("CREATE (i1:Instrument{name: 'Drums'}), (i2:Instrument{name: 'Guitar'})").consume()
        }
      )

    val df = ss.read
      .format(classOf[DataSource].getName)
      .option("url", SparkConnectorScalaSuiteIT.server.getBoltUrl)
      .option("labels", "Instrument")
      .option("partitions", "2")
      .load

    assertEquals(2, df.count())
  }

  // An empty result set must still yield the schema inferred from the RETURN clause.
  @Test
  def testEmptyDataset(): Unit = {
    val df = ss.read
      .format(classOf[DataSource].getName)
      .option("url", SparkConnectorScalaSuiteIT.server.getBoltUrl)
      .option("query", "MATCH (e:ID_DO_NOT_EXIST) RETURN id(e) as f, 1 as g")
      .load

    assertEquals(0, df.count())
    assertEquals(Set("f", "g"), df.columns.toSet)
  }

  // Columns must keep the order of the RETURN clause.
  @Test
  def testColumnSorted(): Unit = {
    SparkConnectorScalaSuiteIT.session()
      .writeTransaction(
        new TransactionWork[ResultSummary] {
          override def execute(tx: Transaction): ResultSummary =
            tx.run("CREATE (i1:Instrument{name: 'Drums', id: 1}), (i2:Instrument{name: 'Guitar', id: 2})").consume()
        }
      )

    val df = ss.read
      .format(classOf[DataSource].getName)
      .option("url", SparkConnectorScalaSuiteIT.server.getBoltUrl)
      .option("query", "MATCH (i:Instrument) RETURN id(i) as internal_id, i.id as id, i.name as name, i.name")
      .load()
      .orderBy("id")

    val row = df.collectAsList().get(0)
    assertEquals(1L, row.get(1))
    assertEquals("Drums", 
row.get(2))
    assertEquals(Set("internal_id", "id", "name", "i.name"), df.columns.toSet)
  }

  @Test
  def testComplexReturnStatement(): Unit = {
    val total = 100
    val fixtureQuery: String =
      s"""UNWIND range(1, $total) as id
        |CREATE (pr:Product {id: id * rand(), name: 'Product ' + id})
        |CREATE (pe:Person {id: id, fullName: 'Person ' + id})
        |CREATE (pe)-[:BOUGHT{when: rand(), quantity: rand() * 1000}]->(pr)
        |RETURN *
        """.stripMargin
    SparkConnectorScalaSuiteIT.session()
      .writeTransaction(
        new TransactionWork[ResultSummary] {
          override def execute(tx: Transaction): ResultSummary = tx.run(fixtureQuery).consume()
        }
      )

    val df = ss.read.format(classOf[DataSource].getName)
      .option("url", SparkConnectorScalaSuiteIT.server.getBoltUrl)
      .option(
        "query",
        """MATCH (p:Person)-[b:BOUGHT]->(pr:Product)
          |RETURN id(p) AS personId, id(pr) AS productId, {quantity: b.quantity, when: b.when} AS map, "some string" as someString, {anotherField: "201"} as map2""".stripMargin
      )
      .option("schema.strategy", "string")
      .load()

    assertEquals(Set("personId", "productId", "map", "someString", "map2"), df.columns.toSet)
    assertEquals(100, df.count())
  }

  // Same shape of RETURN clause but no matching data: schema must still be inferred.
  @Test
  def testComplexReturnStatementNoValues(): Unit = {
    val df = ss.read.format(classOf[DataSource].getName)
      .option("url", SparkConnectorScalaSuiteIT.server.getBoltUrl)
      .option(
        "query",
        """MATCH (p:Person)-[b:BOUGHT]->(pr:Product)
          |RETURN id(p) AS personId, id(pr) AS productId, {quantity: b.quantity, when: b.when} AS map, "some string" as someString, {anotherField: "201", and: 1} as map2""".stripMargin
      )
      .option("schema.strategy", "string")
      .load()

    assertEquals(Set("personId", "productId", "map", "someString", "map2"), df.columns.toSet)
    assertEquals(0, df.count())
  }

  // The "script" option runs first; its result is exposed to the query as `scriptResult`.
  @Test
  def testShouldPassTheScriptResult(): Unit = {
    val df = ss.read
      .format(classOf[DataSource].getName)
      .option("url", SparkConnectorScalaSuiteIT.server.getBoltUrl)
      .option("script", "RETURN 'foo' AS val")
      .option("query", "UNWIND range(1,2) as id RETURN id AS val, scriptResult[0].val AS script")
      .option("partitions", 2)
      .option("query.count", 2)
      .load
      .orderBy("val")
    val data = df.collect()
      .map(row => (row.getAs[String]("script"), row.getAs[Long]("val")))
      .toSeq
    val expected = Seq(("foo", 1), ("foo", 2))
    assertEquals(expected, data)
  }

  @Test
  def testShouldFailWithExplicitErrorIfSkipLimitIsUsedAtTheEndOfTheQuery(): Unit = {
    try {
      ss.read
        .format(classOf[DataSource].getName)
        .option("url", SparkConnectorScalaSuiteIT.server.getBoltUrl)
        .option("query", "MATCH (n:Label) RETURN id(n) as id LIMIT 100")
        .option("partitions", 2)
        .option("query.count", 2)
        .load()
        .show() // we need the action to be able to trigger the exception because of the changes in Spark 3
      org.junit.Assert.fail("Expected to throw an exception")
    } catch {
      case iae: IllegalArgumentException => {
        assertTrue(iae.getMessage.equals("SKIP/LIMIT are not allowed at the end of the query"))
      }
      case _: Throwable => fail(s"should be thrown a ${classOf[IllegalArgumentException].getName}")
    }
  }

  @Test
  def testShouldFailWithExplicitErrorIfLowercaseSkipLimitIsUsedAtTheEndOfTheQuery(): Unit = {
    try {
      ss.read
        .format(classOf[DataSource].getName)
        .option("url", SparkConnectorScalaSuiteIT.server.getBoltUrl)
        .option("query", "MATCH (n:Label) RETURN id(n) as id limit 100 skip 2")
        .option("partitions", 2)
        .option("query.count", 2)
        .load()
        .show() // we need the action to be able to trigger the exception because of the changes in Spark 3
      org.junit.Assert.fail("Expected to throw an exception")
    } catch {
      case iae: IllegalArgumentException => {
        assertTrue(iae.getMessage.equals("SKIP/LIMIT are not allowed at the end of the query"))
      }
      case _: Throwable => fail(s"should be thrown a ${classOf[IllegalArgumentException].getName}")
    }
  }

  @Test
  def testShouldFailWithExplicitErrorIfRandomcaseSkipLimitIsUsedAtTheEndOfTheQuery(): Unit = {
    try {
      ss.read
        .format(classOf[DataSource].getName)
        .option("url", SparkConnectorScalaSuiteIT.server.getBoltUrl)
        .option("query", "MATCH (n:Label) RETURN id(n) as id LiMIt 100 skIp 2")
        .option("partitions", 2)
        .option("query.count", 2)
        .load()
        .show() // we need the action to be able to trigger the exception because of the changes in Spark 3
      org.junit.Assert.fail("Expected to throw an exception")
    } catch {
      case iae: IllegalArgumentException => {
        assertTrue(iae.getMessage.equals("SKIP/LIMIT are not allowed at the end of the query"))
      }
      case _: Throwable => fail(s"should be thrown a ${classOf[IllegalArgumentException].getName}")
    }
  }

  @Test
  def testShouldFailWithExplicitErrorIfSkipLimitIsUsedAtTheEndOfTheQueryWithMultilineQuery(): Unit = {
    try {
      ss.read
        .format(classOf[DataSource].getName)
        .option("url", SparkConnectorScalaSuiteIT.server.getBoltUrl)
        .option(
          "query",
          "MATCH (n:Label)\n" +
            "RETURN id(n) as id\n" +
            "LIMIT 100"
        )
        .option("partitions", 2)
        .option("query.count", 2)
        .load()
        .show() // we need the action to be able to trigger the exception because of the changes in Spark 3
      org.junit.Assert.fail("Expected to throw an exception")
    } catch {
      case iae: IllegalArgumentException => {
        assertTrue(iae.getMessage.equals("SKIP/LIMIT are not allowed at the end of the query"))
      }
      case t: Throwable => {
        t.printStackTrace()
        fail(s"should be thrown a ${classOf[IllegalArgumentException].getName}")
      }
    }
  }

  // SKIP/LIMIT inside the query body (not at the end) is legitimate and must work.
  @Test
  def testShouldAllowSkipLimitInsideTheQuery(): Unit = {
    val total = 100
    val fixtureQuery: String =
      s"""UNWIND range(1, $total) as id
        |CREATE (pr:Product {id: id, name: 'Product ' + id})
        |RETURN *
        """.stripMargin
    SparkConnectorScalaSuiteIT.session()
      .writeTransaction(
        new TransactionWork[ResultSummary] {
          override def execute(tx: Transaction): ResultSummary = tx.run(fixtureQuery).consume()
        }
      )

    val df = ss.read
      .format(classOf[DataSource].getName)
      .option("url", SparkConnectorScalaSuiteIT.server.getBoltUrl)
      .option("query", "MATCH (p:Product) WITH p\nLIMIT 10\nRETURN p")
      .option("partitions", 2)
      .option("query.count", 20)
      .load

    assertEquals(10, df.count())
  }

  @Test
  def testShouldReturnJustTheSelectedFieldWithNode(): Unit = {
    val total = 100
    val fixtureQuery: 
String = s"""UNWIND range(1, $total) as id |CREATE (pr:Product {id: id, name: 'Product ' + id}) |RETURN * """.stripMargin SparkConnectorScalaSuiteIT.session() .writeTransaction( new TransactionWork[ResultSummary] { override def execute(tx: Transaction): ResultSummary = tx.run(fixtureQuery).consume() } ) val df = ss.read .format(classOf[DataSource].getName) .option("url", SparkConnectorScalaSuiteIT.server.getBoltUrl) .option("labels", "Product") .load .select("name") df.count() assertEquals(Set("name"), df.columns.toSet) } @Test def testShouldReturnJustTheSelectedFieldWithNodeAndWeirdColumnName(): Unit = { val total = 100 val fixtureQuery: String = s"""UNWIND range(1, $total) as id |CREATE (pr:Product {id: id, `(╯°□°)╯︵ ┻━┻`: 'Product ' + id}) |RETURN * """.stripMargin SparkConnectorScalaSuiteIT.session() .writeTransaction( new TransactionWork[ResultSummary] { override def execute(tx: Transaction): ResultSummary = tx.run(fixtureQuery).consume() } ) val df = ss.read .format(classOf[DataSource].getName) .option("url", SparkConnectorScalaSuiteIT.server.getBoltUrl) .option("labels", "Product") .load .select("`(╯°□°)╯︵ ┻━┻`") df.count() assertEquals(Set("(╯°□°)╯︵ ┻━┻"), df.columns.toSet) } @Test def testShouldSelectTheSystemColumnsInRelationship(): Unit = { val total = 100 val fixtureQuery: String = s"""UNWIND range(1, $total) as id |CREATE (pr:Product {id: id * rand(), name: 'Product ' + id}) |CREATE (pe:Person {id: id, fullName: 'Person ' + id}) |CREATE (pe)-[:BOUGHT{when: rand(), quantity: rand() * 1000}]->(pr) |RETURN * """.stripMargin SparkConnectorScalaSuiteIT.session() .writeTransaction( new TransactionWork[ResultSummary] { override def execute(tx: Transaction): ResultSummary = tx.run(fixtureQuery).consume() } ) val df = ss.read .format(classOf[DataSource].getName) .option("url", SparkConnectorScalaSuiteIT.server.getBoltUrl) .option("relationship", "BOUGHT") .option("relationship.source.labels", "Person") .option("relationship.target.labels", "Product") .load 
.select("``") df.collect() assertEquals(Set(""), df.columns.toSet) } @Test def testShouldReturnJustTheSelectedFieldWithRelationship(): Unit = { val total = 100 val fixtureQuery: String = s"""UNWIND range(1, $total) as id |CREATE (pr:Product {id: id * rand(), name: 'Product ' + id}) |CREATE (pe:Person {id: id, fullName: 'Person ' + id}) |CREATE (pe)-[:BOUGHT{when: rand(), quantity: rand() * 1000}]->(pr) |RETURN * """.stripMargin SparkConnectorScalaSuiteIT.session() .writeTransaction( new TransactionWork[ResultSummary] { override def execute(tx: Transaction): ResultSummary = tx.run(fixtureQuery).consume() } ) val df = ss.read .format(classOf[DataSource].getName) .option("url", SparkConnectorScalaSuiteIT.server.getBoltUrl) .option("relationship", "BOUGHT") .option("relationship.source.labels", "Product") .option("relationship.target.labels", "Person") .load .select("`source.name`", "``") df.count() assertEquals(Set("source.name", ""), df.columns.toSet) } @Test def testShouldReturnJustTheSelectedFieldWithRelationshipAndWeirdColumn(): Unit = { val total = 100 val fixtureQuery: String = s"""UNWIND range(1, $total) as id |CREATE (pr:Product {id: id * rand(), `(╯°□°)╯︵ ┻━┻`: 'Product ' + id}) |CREATE (pe:Person {id: id, fullName: 'Person ' + id}) |CREATE (pe)-[:BOUGHT{when: rand(), quantity: rand() * 1000}]->(pr) |RETURN * """.stripMargin SparkConnectorScalaSuiteIT.session() .writeTransaction( new TransactionWork[ResultSummary] { override def execute(tx: Transaction): ResultSummary = tx.run(fixtureQuery).consume() } ) val df = ss.read .format(classOf[DataSource].getName) .option("url", SparkConnectorScalaSuiteIT.server.getBoltUrl) .option("relationship", "BOUGHT") .option("relationship.source.labels", "Person") .option("relationship.target.labels", "Product") .load .select("`target.(╯°□°)╯︵ ┻━┻`", "``") df.count() assertEquals(Set("target.(╯°□°)╯︵ ┻━┻", ""), df.columns.toSet) } @Test def testShouldReturnJustTheSelectedFieldWithQuery(): Unit = { val total = 100 val 
fixtureQuery: String = s"""UNWIND range(1, $total) as id |CREATE (pr:Product {id: id * rand(), name: 'Product ' + id}) |CREATE (pe:Person {id: id, fullName: 'Person ' + id}) |CREATE (pe)-[:BOUGHT{when: rand(), quantity: rand() * 1000}]->(pr) |RETURN * """.stripMargin SparkConnectorScalaSuiteIT.session() .writeTransaction( new TransactionWork[ResultSummary] { override def execute(tx: Transaction): ResultSummary = tx.run(fixtureQuery).consume() } ) val df = ss.read .format(classOf[DataSource].getName) .option("url", SparkConnectorScalaSuiteIT.server.getBoltUrl) .option("query", "MATCH (p:Product) RETURN p.name as name") .option("partitions", 2) .option("query.count", 20) .load .select("name") df.count() assertEquals(Set("name"), df.columns.toSet) } @Test def testShouldReturnJustTheSelectedFieldWithFilter(): Unit = { val total = 100 val fixtureQuery: String = s"""UNWIND range(1, $total) as id |CREATE (pr:Product {id: id, name: 'Product ' + id}) |RETURN * """.stripMargin SparkConnectorScalaSuiteIT.session() .writeTransaction( new TransactionWork[ResultSummary] { override def execute(tx: Transaction): ResultSummary = tx.run(fixtureQuery).consume() } ) val df = ss.read .format(classOf[DataSource].getName) .option("url", SparkConnectorScalaSuiteIT.server.getBoltUrl) .option("labels", "Product") .load .filter("name = 'Product 1'") df.count() assertEquals(Set("", "", "name", "id"), df.columns.toSet) } @Test def testShouldReturnJustTheSelectedFieldWithRelationshipWithFilter(): Unit = { val total = 100 val fixtureQuery: String = s"""UNWIND range(1, $total) as id |CREATE (pr:Product {id: id * rand(), name: 'Product ' + id}) |CREATE (pe:Person {id: id, fullName: 'Person ' + id}) |CREATE (pe)-[:BOUGHT{when: rand(), quantity: rand() * 1000}]->(pr) |RETURN * """.stripMargin SparkConnectorScalaSuiteIT.session() .writeTransaction( new TransactionWork[ResultSummary] { override def execute(tx: Transaction): ResultSummary = tx.run(fixtureQuery).consume() } ) val df = ss.read 
.format(classOf[DataSource].getName)
      .option("url", SparkConnectorScalaSuiteIT.server.getBoltUrl)
      .option("relationship", "BOUGHT")
      .option("relationship.source.labels", "Person")
      .option("relationship.target.labels", "Product")
      .load
      .filter("`target.name` = 'Product 1' AND `target.id` = '16'")
      .select("`target.name`", "`target.id`")
    df.count()
    assertEquals(Set("target.name", "target.id"), df.columns.toSet)
  }

  // A user-supplied schema overrides inference: the age comes back as a string.
  @Test
  def testShouldUseTheUserSpecifiedSchema(): Unit = {
    SparkConnectorScalaSuiteIT.session()
      .writeTransaction(
        new TransactionWork[ResultSummary] {
          override def execute(tx: Transaction): ResultSummary =
            tx.run("CREATE (p:Person {name: 'Foo Bar', age: 8})").consume()
        }
      )

    val df = ss.read.format(classOf[DataSource].getName)
      .schema(StructType(Seq(StructField("age", DataTypes.StringType)).toSeq))
      .option("url", SparkConnectorScalaSuiteIT.server.getBoltUrl)
      .option("query", "MATCH (n:Person) RETURN n.age AS age")
      .load()

    assertEquals("8", df.collect().head.get(0))
  }

  // ORDER BY at the end of a query is allowed (unlike trailing SKIP/LIMIT).
  @Test
  def testQueryWithOrderByShouldBeAllowed(): Unit = {
    SparkConnectorScalaSuiteIT.session()
      .writeTransaction(
        new TransactionWork[ResultSummary] {
          override def execute(tx: Transaction): ResultSummary =
            tx.run("CREATE (p:Person {name: 'Foo Bar', age: 8})").consume()
        }
      )

    val df = ss.read.format(classOf[DataSource].getName)
      .option("url", SparkConnectorScalaSuiteIT.server.getBoltUrl)
      .option("query", "MATCH (n:Person) RETURN n.age AS age ORDER by age")
      .load()

    assertEquals(8L, df.collect().head.get(0))
  }

  // limit() must be pushed down without dropping the internal "<id>"/"<labels>" columns.
  @Test
  def testShouldLimitTheNodeResults(): Unit = {
    val total = 100
    val fixtureQuery: String =
      s"""UNWIND range(1, $total) as id
        |CREATE (pr:Product {id: id, name: 'Product ' + id})
        |RETURN *
        """.stripMargin
    SparkConnectorScalaSuiteIT.session()
      .writeTransaction((tx: Transaction) => tx.run(fixtureQuery).consume())

    val df = ss.read
      .format(classOf[DataSource].getName)
      .option("url", SparkConnectorScalaSuiteIT.server.getBoltUrl)
      .option("labels", "Product")
      .load
      .limit(10)

    assertEquals(10, df.count())
    assertEquals(Set("<id>", "<labels>", "name", "id"), df.columns.toSet)
  }

  @Test
  def testShouldLimitTheRelationshipResults(): Unit = {
    val total = 100
    val fixtureQuery: String =
      s"""UNWIND range(1, $total) as id
        |CREATE (pr:Product {id: id * rand(), name: 'Product ' + id})
        |CREATE (pe:Person {id: id, fullName: 'Person ' + id})
        |CREATE (pe)-[:BOUGHT{when: rand(), quantity: rand() * 1000}]->(pr)
        |RETURN *
        """.stripMargin
    SparkConnectorScalaSuiteIT.session()
      .writeTransaction((tx: Transaction) => tx.run(fixtureQuery).consume())

    val df = ss.read
      .format(classOf[DataSource].getName)
      .option("url", SparkConnectorScalaSuiteIT.server.getBoltUrl)
      .option("relationship", "BOUGHT")
      .option("relationship.source.labels", "Person")
      .option("relationship.target.labels", "Product")
      .load
      .select("`target.name`", "`target.id`")
      .limit(10)

    assertEquals(10, df.count())
    assertEquals(Set("target.name", "target.id"), df.columns.toSet)
  }

  // Regression test for issue 531: a datetime column containing nulls must not break inference.
  @Test
  def test531(): Unit = {
    val dataFrame = ss.read.format(classOf[DataSource].getName)
      .option("url", SparkConnectorScalaSuiteIT.server.getBoltUrl)
      .option(
        "query",
        """
          |UNWIND [
          |  {first: '2022-06-14T10:02:28.192Z', second: null},
          |  {first: '2022-06-15T10:02:28.192Z', second: '2022-06-16T10:02:28.192Z'}
          |]AS event
          |RETURN datetime(event.first) AS first, datetime(event.second) AS second
          |""".stripMargin
      )
      .load()

    // the nullable column is inferred as a string, the non-null one as a timestamp
    assertEquals(
      StructType(Array(
        StructField("first", DataTypes.TimestampType),
        StructField("second", DataTypes.StringType)
      )),
      dataFrame.schema
    )
    assertEquals(
      List(
        (Timestamp.from(OffsetDateTime.parse("2022-06-14T10:02:28.192Z").toInstant), null),
        (Timestamp.from(OffsetDateTime.parse("2022-06-15T10:02:28.192Z").toInstant), "2022-06-16T10:02:28.192Z")
      ),
      dataFrame.collect()
        .map(r => (r.getTimestamp(0), r.getString(1)))
        .toList
    )
  }

  @Test
  def test531WithSchema(): Unit = {
    val schema = StructType(Array(
      StructField("first", DataTypes.TimestampType),
      StructField("second", DataTypes.TimestampType)
    ))
    // dataFrameWithSchema must come back with the proper data type
    val dataFrameWithSchema = ss.read.format(classOf[DataSource].getName)
      .option("url", SparkConnectorScalaSuiteIT.server.getBoltUrl)
      .option(
        "query",
        """
          |UNWIND [
          |  {first: '2022-06-14T10:02:28.192Z', second: null},
          |  {first: '2022-06-15T10:02:28.192Z', second: '2022-06-16T10:02:28.192Z'}
          |] AS event
          |RETURN datetime(event.first) AS first, datetime(event.second) AS second
          |""".stripMargin
      )
      .schema(schema)
      .load()

    assertEquals(
      schema,
      dataFrameWithSchema.schema
    )
    assertEquals(
      List(
        (Timestamp.from(OffsetDateTime.parse("2022-06-14T10:02:28.192Z").toInstant), null),
        (
          Timestamp.from(OffsetDateTime.parse("2022-06-15T10:02:28.192Z").toInstant),
          Timestamp.from(OffsetDateTime.parse("2022-06-16T10:02:28.192Z").toInstant)
        )
      ),
      dataFrameWithSchema.collect()
        .map(r => (r.getTimestamp(0), r.getTimestamp(1)))
        .toList
    )
  }

  @Test
  def testShouldAggregateAndLimitTheRelationshipResults(): Unit = {
    val total = 100
    val fixtureQuery: String =
      s"""UNWIND range(1, $total) as id
        |CREATE (pr:Product {id: id * rand(), name: 'Product ' + id})
        |CREATE (pe:Person {id: id, fullName: 'Person ' + id})
        |CREATE (pe)-[:BOUGHT{when: rand(), quantity: rand() * 1000}]->(pr)
        |RETURN *
        """.stripMargin
    SparkConnectorScalaSuiteIT.session()
      .writeTransaction((tx: Transaction) => tx.run(fixtureQuery).consume())

    val df = ss.read
      .format(classOf[DataSource].getName)
      .option("url", SparkConnectorScalaSuiteIT.server.getBoltUrl)
      .option("relationship", "BOUGHT")
      .option("relationship.source.labels", "Person")
      .option("relationship.target.labels", "Product")
      .load
      .select("`target.name`", "`target.id`")
      .orderBy(col("`target.name`").desc)
      .limit(10)
    df.show()
    assertEquals(10, df.count())
    assertEquals(Set("target.name", "target.id"), df.columns.toSet)
  }

  // Shared fixture helper: runs `query`, then returns a DataFrame over the Person label.
  private def initTest(query: String): DataFrame = {
    SparkConnectorScalaSuiteIT.session()
      .writeTransaction(
        new TransactionWork[ResultSummary] {
          override def execute(tx: Transaction): ResultSummary = tx.run(query).consume()
        }
      )
ss.read.format(classOf[DataSource].getName) .option("url", SparkConnectorScalaSuiteIT.server.getBoltUrl) .option("labels", "Person") .load() } } ================================================ FILE: spark-3/src/test/scala/org/neo4j/spark/DataSourceReaderWithApocTSE.scala ================================================ /* * Copyright (c) "Neo4j" * Neo4j Sweden AB [https://neo4j.com] * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ package org.neo4j.spark import org.apache.spark.sql.DataFrame import org.apache.spark.sql.catalyst.expressions.GenericRowWithSchema import org.junit.Assert._ import org.junit.Test import org.neo4j.driver.Transaction import org.neo4j.driver.TransactionWork import org.neo4j.driver.summary.ResultSummary import java.sql.Timestamp import java.time.Instant import java.time.LocalDateTime import java.time.LocalTime import java.time.OffsetDateTime import java.time.OffsetTime import java.time.ZoneOffset import java.util.TimeZone import scala.collection.JavaConverters._ import scala.collection.mutable.ArraySeq import scala.collection.mutable.Seq class DataSourceReaderWithApocTSE extends SparkConnectorScalaBaseWithApocTSE { @Test def testReadNodeHasIdField(): Unit = { val df: DataFrame = initTest(s"CREATE (p:Person {name: 'John'})") /** * utnaf: Since we can't be sure we are in total isolation, and the id is generated * internally by org.neo4j.neo4j, we just check that the field is an integer and is greater * than -1 */ 
assertTrue(df.select("").collectAsList().get(0).getLong(0) > -1) } @Test def testReadNodeHasLabelsField(): Unit = { val df: DataFrame = initTest(s"CREATE (p:Person:Customer {name: 'John'})") val result = df.select("").collectAsList().get(0).getAs[Seq[String]](0) assertEquals("Person", result.head) assertEquals("Customer", result(1)) } @Test def testReadNodeHasUnusualLabelsField(): Unit = { val df: DataFrame = initTest(s"CREATE (p:`Foo Bar`:Person {name: 'John'})") val result = df.select("").collectAsList().get(0).getAs[Seq[String]](0) assertEquals(Set("Person", "Foo Bar"), result.toSet[String]) } @Test def testReadNodeWithFieldWithDifferentTypes(): Unit = { val df: DataFrame = initTest("CREATE (p1:Person {id: 1, field: [12,34]}), (p2:Person {id: 2, field: 123})") val res = df.orderBy("id").collectAsList() assertEquals("[12,34]", res.get(0).get(3)) assertEquals("123", res.get(1).get(3)) } @Test def testReadNodeWithString(): Unit = { val name: String = "John" val df: DataFrame = initTest(s"CREATE (p:Person {name: '$name'})") assertEquals(name, df.select("name").collectAsList().get(0).getString(0)) } @Test def testReadNodeWithLong(): Unit = { val age: Long = 42 val df: DataFrame = initTest(s"CREATE (p:Person {age: $age})") assertEquals(age, df.select("age").collectAsList().get(0).getLong(0)) } @Test def testReadNodeWithDouble(): Unit = { val score: Double = 3.14 val df: DataFrame = initTest(s"CREATE (p:Person {score: $score})") assertEquals(score, df.select("score").collectAsList().get(0).getDouble(0), 0) } @Test def testReadNodeWithLocalTime(): Unit = { val df: DataFrame = initTest(s"CREATE (p:Person {aTime: localtime({hour:12, minute: 23, second: 0, millisecond: 294})})") val result = df.select("aTime").collectAsList().get(0).getAs[GenericRowWithSchema](0) assertEquals("local-time", result.get(0)) assertEquals("12:23:00.294", result.get(1)) } @Test def testReadNodeWithTime(): Unit = { val timezone = TimeZone.getDefault val df: DataFrame = initTest(s"CREATE (p:Person 
{aTime: time({hour:12, minute: 23, second: 0, millisecond: 294})})") val result = df.select("aTime").collectAsList().get(0).getAs[GenericRowWithSchema](0) val localTime = LocalTime.of(12, 23, 0, 294000000) val expectedTime = OffsetTime.of(localTime, timezone.toZoneId.getRules.getOffset(Instant.now())) assertEquals("offset-time", result.get(0)) assertEquals(expectedTime.toString, result.get(1)) } @Test def testReadNodeWithLocalDateTime(): Unit = { val localDateTime = "2007-12-03T10:15:30" val df: DataFrame = initTest(s"CREATE (p:Person {aTime: localdatetime('$localDateTime')})") val result = df.select("aTime").collectAsList().get(0).getAs[LocalDateTime](0) assertEquals(LocalDateTime.parse(localDateTime), result) } @Test def testReadNodeWithZonedDateTime(): Unit = { val datetime = "2015-06-24T12:50:35.556+01:00" val df: DataFrame = initTest(s"CREATE (p:Person {aTime: datetime('$datetime')})") val result = df.select("aTime").collectAsList().get(0).getTimestamp(0) assertEquals(Timestamp.from(OffsetDateTime.parse(datetime).toInstant), result) } @Test def testReadNodeWithPoint(): Unit = { val df: DataFrame = initTest(s"CREATE (p:Person {location: point({x: 12.12, y: 13.13})})") val res = df.select("location").collectAsList().get(0).getAs[GenericRowWithSchema](0); assertEquals("point-2d", res.get(0)) assertEquals(7203, res.get(1)) assertEquals(12.12, res.get(2)) assertEquals(13.13, res.get(3)) } @Test def testReadNodeWithGeoPoint(): Unit = { val df: DataFrame = initTest(s"CREATE (p:Person {location: point({longitude: 12.12, latitude: 13.13})})") val res = df.select("location").collectAsList().get(0).getAs[GenericRowWithSchema](0); assertEquals("point-2d", res.get(0)) assertEquals(4326, res.get(1)) assertEquals(12.12, res.get(2)) assertEquals(13.13, res.get(3)) } @Test def testReadNodeWithPoint3D(): Unit = { val df: DataFrame = initTest(s"CREATE (p:Person {location: point({x: 12.12, y: 13.13, z: 1})})") val res = 
      df.select("location").collectAsList().get(0).getAs[GenericRowWithSchema](0)
    // (tail of the preceding point-3d test: type tag, SRID 9157, then x/y/z)
    assertEquals("point-3d", res.get(0))
    assertEquals(9157, res.get(1))
    assertEquals(12.12, res.get(2))
    assertEquals(13.13, res.get(3))
    assertEquals(1.0, res.get(4))
  }

  // Neo4j DATE values surface as java.sql.Date columns.
  @Test
  def testReadNodeWithDate(): Unit = {
    val df: DataFrame = initTest(s"CREATE (p:Person {born: date('2009-10-10')})")
    val list = df.select("born").collectAsList()
    val res = list.get(0).getDate(0)
    assertEquals(java.sql.Date.valueOf("2009-10-10"), res)
  }

  // Durations come back as a struct: (type, months, days, seconds, nanos, string form).
  @Test
  def testReadNodeWithDuration(): Unit = {
    val df: DataFrame = initTest(s"CREATE (p:Person {range: duration({days: 14, hours:16, minutes: 12})})")
    val list = df.select("range").collectAsList()
    val res = list.get(0).getAs[GenericRowWithSchema](0)
    assertEquals("duration", res(0))
    assertEquals(0L, res(1))
    assertEquals(14L, res(2))
    assertEquals(58320L, res(3))
    assertEquals(0, res(4))
    assertEquals("P0M14DT58320S", res(5))
  }

  @Test
  def testReadNodeWithStringArray(): Unit = {
    val df: DataFrame = initTest(s"CREATE (p:Person {names: ['John', 'Doe']})")
    val res = df.select("names").collectAsList().get(0).getAs[Seq[String]](0)
    assertEquals("John", res.head)
    assertEquals("Doe", res(1))
  }

  @Test
  def testReadNodeWithLongArray(): Unit = {
    val df: DataFrame = initTest(s"CREATE (p:Person {ages: [22, 23]})")
    val res = df.select("ages").collectAsList().get(0).getAs[Seq[Long]](0)
    assertEquals(22, res.head)
    assertEquals(23, res(1))
  }

  @Test
  def testReadNodeWithDoubleArray(): Unit = {
    val df: DataFrame = initTest(s"CREATE (p:Person {scores: [22.33, 44.55]})")
    val res = df.select("scores").collectAsList().get(0).getAs[Seq[Double]](0)
    assertEquals(22.33, res.head, 0)
    assertEquals(44.55, res(1), 0)
  }

  // LOCAL TIME array elements arrive as (type, value-string) structs.
  @Test
  def testReadNodeWithLocalTimeArray(): Unit = {
    val df: DataFrame = initTest(s"CREATE (p:Person {someTimes: [localtime({hour:12}), localtime({hour:1, minute: 3})]})")
    val res = df.select("someTimes").collectAsList().get(0).getAs[Seq[GenericRowWithSchema]](0)
    assertEquals("local-time", res.head.get(0))
    assertEquals("12:00:00", res.head.get(1))
    assertEquals("local-time", res(1).get(0))
    assertEquals("01:03:00", res(1).get(1))
  }

  @Test
  def testReadNodeWithBooleanArray(): Unit = {
    val df: DataFrame = initTest(s"CREATE (p:Person {bools: [true, false]})")
    val res = df.select("bools").collectAsList().get(0).getAs[Seq[Boolean]](0)
    assertEquals(true, res.head)
    assertEquals(false, res(1))
  }

  // Cartesian 2D points: SRID 7203.
  @Test
  def testReadNodeWithPointArray(): Unit = {
    val df: DataFrame = initTest(s"CREATE (p:Person {locations: [point({x: 11, y: 33.111}), point({x: 22, y: 44.222})]})")
    val res = df.select("locations").collectAsList().get(0).getAs[Seq[GenericRowWithSchema]](0)
    assertEquals("point-2d", res.head.get(0))
    assertEquals(7203, res.head.get(1))
    assertEquals(11.0, res.head.get(2))
    assertEquals(33.111, res.head.get(3))
    assertEquals("point-2d", res(1).get(0))
    assertEquals(7203, res(1).get(1))
    assertEquals(22.0, res(1).get(2))
    assertEquals(44.222, res(1).get(3))
  }

  // Geographic (WGS-84) 2D points: SRID 4326.
  @Test
  def testReadNodeWithGeoPointArray(): Unit = {
    val df: DataFrame = initTest(
      s"CREATE (p:Person {locations: [point({longitude: 11, latitude: 33.111}), point({longitude: 22, latitude: 44.222})]})"
    )
    val res = df.select("locations").collectAsList().get(0).getAs[Seq[GenericRowWithSchema]](0)
    assertEquals("point-2d", res.head.get(0))
    assertEquals(4326, res.head.get(1))
    assertEquals(11.0, res.head.get(2))
    assertEquals(33.111, res.head.get(3))
    assertEquals("point-2d", res(1).get(0))
    assertEquals(4326, res(1).get(1))
    assertEquals(22.0, res(1).get(2))
    assertEquals(44.222, res(1).get(3))
  }

  // Cartesian 3D points: SRID 9157, with a z coordinate at index 4.
  @Test
  def testReadNodeWithPoint3DArray(): Unit = {
    val df: DataFrame =
      initTest(s"CREATE (p:Person {locations: [point({x: 11, y: 33.111, z: 12}), point({x: 22, y: 44.222, z: 99.1})]})")
    val res = df.select("locations").collectAsList().get(0).getAs[Seq[GenericRowWithSchema]](0)
    assertEquals("point-3d", res.head.get(0))
    assertEquals(9157, res.head.get(1))
    assertEquals(11.0, res.head.get(2))
    assertEquals(33.111, res.head.get(3))
    assertEquals(12.0, res.head.get(4))
    assertEquals("point-3d", res(1).get(0))
    assertEquals(9157, res(1).get(1))
    assertEquals(22.0, res(1).get(2))
    assertEquals(44.222, res(1).get(3))
    assertEquals(99.1, res(1).get(4))
  }

  @Test
  def testReadNodeWithArrayDate(): Unit = {
    val df: DataFrame = initTest(s"CREATE (p:Person {dates: [date('2009-10-10'), date('2009-10-11')]})")
    val res = df.select("dates").collectAsList().get(0).getAs[Seq[java.sql.Date]](0)
    assertEquals(java.sql.Date.valueOf("2009-10-10"), res.head)
    assertEquals(java.sql.Date.valueOf("2009-10-11"), res(1))
  }

  // Zoned datetimes map to java.sql.Timestamp (instant-preserving).
  @Test
  def testReadNodeWithArrayZonedDateTime(): Unit = {
    val datetime1 = "2015-06-24T12:50:35.556+01:00"
    val datetime2 = "2015-06-23T12:50:35.556+01:00"
    val df: DataFrame = initTest(s""" CREATE (p:Person {aTime: [ datetime('$datetime1'), datetime('$datetime2') ]}) """)
    val result = df.select("aTime").collectAsList().get(0).getAs[Seq[Timestamp]](0)
    assertEquals(Timestamp.from(OffsetDateTime.parse(datetime1).toInstant), result.head)
    assertEquals(Timestamp.from(OffsetDateTime.parse(datetime2).toInstant), result(1))
  }

  // Fractional duration components are normalized by Neo4j (0.75 months -> 22d 71509.5s, 2.5 weeks -> 17d 43200s).
  @Test
  def testReadNodeWithArrayDurations(): Unit = {
    val df: DataFrame = initTest(s"CREATE (p:Person {durations: [duration({months: 0.75}), duration({weeks: 2.5})]})")
    val res = df.select("durations").collectAsList().get(0).getAs[Seq[GenericRowWithSchema]](0)
    assertEquals("duration", res.head.get(0))
    assertEquals(0L, res.head.get(1))
    assertEquals(22L, res.head.get(2))
    assertEquals(71509L, res.head.get(3))
    assertEquals(500000000, res.head.get(4))
    assertEquals("P0M22DT71509.500000000S", res.head.get(5))
    assertEquals("duration", res(1).get(0))
    assertEquals(0L, res(1).get(1))
    assertEquals(17L, res(1).get(2))
    assertEquals(43200L, res(1).get(3))
    assertEquals(0, res(1).get(4))
    assertEquals("P0M17DT43200S", res(1).get(5))
  }

  // --- Predicate push-down tests: each exercises one Spark filter operator. ---

  @Test
  def testReadNodeWithEqualToFilter(): Unit = {
    val df: DataFrame = initTest(s""" CREATE (p1:Person {name: 'John Doe'}), (p2:Person {name: 'Jane Doe'}) """)
    val result = df.select("name").where("name = 'John Doe'").collectAsList()
    assertEquals(1, result.size())
    assertEquals("John Doe", result.get(0).getString(0))
  }

  @Test
  def testReadNodeWithNotEqualToFilter(): Unit = {
    val df: DataFrame = initTest(s""" CREATE (p1:Person {name: 'John Doe'}), (p2:Person {name: 'Jane Doe'}) """)
    val result = df.select("name").where("NOT name = 'John Doe'").collectAsList()
    assertEquals(1, result.size())
    assertEquals("Jane Doe", result.get(0).getString(0))
  }

  @Test
  def testReadNodeWithDifferentOperatorFilter(): Unit = {
    val df: DataFrame = initTest(s""" CREATE (p1:Person {name: 'John Doe'}), (p2:Person {name: 'Jane Doe'}) """)
    val result = df.select("name").where("name != 'John Doe'").collectAsList()
    assertEquals(1, result.size())
    assertEquals("Jane Doe", result.get(0).getString(0))
  }

  @Test
  def testReadNodeWithGtFilter(): Unit = {
    val df: DataFrame = initTest(s""" CREATE (p1:Person {age: 19}), (p2:Person {age: 20}), (p3:Person {age: 21}) """)
    val result = df.select("age").where("age > 20").collectAsList()
    assertEquals(1, result.size())
    assertEquals(21, result.get(0).getLong(0))
  }

  @Test
  def testReadNodeWithGteFilter(): Unit = {
    val df: DataFrame = initTest(s""" CREATE (p1:Person {age: 19}), (p2:Person {age: 20}), (p3:Person {age: 21}) """)
    val result = df.select("age").orderBy("age").where("age >= 20").collectAsList()
    assertEquals(2, result.size())
    assertEquals(20, result.get(0).getLong(0))
    assertEquals(21, result.get(1).getLong(0))
  }

  // Property-to-property comparison (score >= limit) rather than property-to-literal.
  @Test
  def testReadNodeWithGteFilterWithProp(): Unit = {
    val df: DataFrame =
      initTest(s""" CREATE (p1:Person {score: 19, limit: 20}), (p2:Person {score: 20, limit: 18}), (p3:Person {score: 21, limit: 12}) """)
    val result = df.select("score").orderBy("score").where("score >= limit").collectAsList()
    assertEquals(2, result.size())
    assertEquals(20, result.get(0).getLong(0))
    assertEquals(21, result.get(1).getLong(0))
  }

  @Test
  def testReadNodeWithLtFilter(): Unit = {
    val df: DataFrame = initTest(s""" CREATE (p1:Person {age: 39}), (p2:Person {age: 41}), (p3:Person {age: 43}) """)
    val result = df.select("age").orderBy("age").where("age < 40").collectAsList()
    assertEquals(1, result.size())
    assertEquals(39, result.get(0).getLong(0))
  }

  @Test
  def testReadNodeWithLteFilter(): Unit = {
    val df: DataFrame = initTest(s""" CREATE (p1:Person {age: 39}), (p2:Person {age: 41}), (p3:Person {age: 43}) """)
    val result = df.select("age").orderBy("age").where("age <= 41").collectAsList()
    assertEquals(2, result.size())
    assertEquals(39, result.get(0).getLong(0))
    assertEquals(41, result.get(1).getLong(0))
  }

  @Test
  def testReadNodeWithInFilter(): Unit = {
    val df: DataFrame = initTest(s""" CREATE (p1:Person {age: 39}), (p2:Person {age: 41}), (p3:Person {age: 43}) """)
    val result = df.select("age").orderBy("age").where("age IN(41,43)").collectAsList()
    assertEquals(2, result.size())
    assertEquals(41, result.get(0).getLong(0))
    assertEquals(43, result.get(1).getLong(0))
  }

  @Test
  def testReadNodeWithIsNullFilter(): Unit = {
    val df: DataFrame = initTest(s""" CREATE (p1:Person {age: 39}), (p2:Person {age: null}), (p3:Person {age: 43}) """)
    val result = df.select("age").where("age IS NULL").collectAsList()
    assertEquals(1, result.size())
    assertNull(result.get(0).get(0))
  }

  @Test
  def testReadNodeWithIsNotNullFilter(): Unit = {
    val df: DataFrame = initTest(s""" CREATE (p1:Person {age: 39}), (p2:Person {age: null}), (p3:Person {age: 43}) """)
    val result = df.select("age").orderBy("age").where("age IS NOT NULL").collectAsList()
    assertEquals(2, result.size())
    assertEquals(39, result.get(0).getLong(0))
    assertEquals(43, result.get(1).getLong(0))
  }

  @Test
  def testReadNodeWithOrCondition(): Unit = {
    val df: DataFrame = initTest(s""" CREATE (p1:Person {age: 39}), (p2:Person {age: null}), (p3:Person {age: 43}) """)
    val result = df.select("age").orderBy("age").where("age = 43 OR age = 39 OR age = 32").collectAsList()
    assertEquals(2, result.size())
    assertEquals(39, result.get(0).getLong(0))
    assertEquals(43, result.get(1).getLong(0))
  }

  @Test
  def testReadNodeWithAndCondition(): Unit = {
    val df: DataFrame = initTest(s""" CREATE (p1:Person {age: 39}), (p2:Person {age: null}), (p3:Person {age: 43}) """)
    val result = df.select("age").orderBy("age").where("age >= 39 AND age <= 43").collectAsList()
    assertEquals(2, result.size())
    assertEquals(39, result.get(0).getLong(0))
    assertEquals(43, result.get(1).getLong(0))
  }

  @Test
  def testReadNodeWithStartsWith(): Unit = {
    val df: DataFrame =
      initTest(s""" CREATE (p1:Person {name: 'John Mayer'}), (p2:Person {name: 'John Scofield'}), (p3:Person {name: 'John Butler'}) """)
    val result = df.select("name").orderBy("name").where("name LIKE 'John%'").collectAsList()
    assertEquals(3, result.size())
    assertEquals("John Butler", result.get(0).getString(0))
    assertEquals("John Mayer", result.get(1).getString(0))
    assertEquals("John Scofield", result.get(2).getString(0))
  }

  @Test
  def testReadNodeWithEndsWith(): Unit = {
    val df: DataFrame =
      initTest(s""" CREATE (p1:Person {name: 'John Mayer'}), (p2:Person {name: 'John Scofield'}), (p3:Person {name: 'John Butler'}) """)
    val result = df.select("name").where("name LIKE '%Scofield'").collectAsList()
    assertEquals(1, result.size())
    assertEquals("John Scofield", result.get(0).getString(0))
  }

  @Test
  def testReadNodeWithContains(): Unit = {
    val df: DataFrame =
      initTest(s""" CREATE (p1:Person {name: 'John Mayer'}), (p2:Person {name: 'John Scofield'}), (p3:Person {name: 'John Butler'}) """)
    val result = df.select("name").where("name LIKE '%ay%'").collectAsList()
    assertEquals(1, result.size())
    assertEquals("John Mayer", result.get(0).getString(0))
  }

  // Filters on the nested node-map representation (relationship.nodes.map = true).
  @Test
  def testRelFiltersWithMap(): Unit = {
    val fixtureQuery: String =
      """UNWIND range(1,100) as id
        |CREATE (p:Person {id:id,ids:[id,id]}) WITH collect(p) as people
        |UNWIND people as p1
        |UNWIND range(1,10) as friend
        |WITH p1, people[(p1.id + friend) % size(people)] as p2
        |CREATE (p1)-[:KNOWS]->(p2)
        |RETURN *
        """.stripMargin
    SparkConnectorScalaSuiteWithApocIT.session()
      .writeTransaction(
        new TransactionWork[ResultSummary] {
          override def execute(tx: Transaction): ResultSummary = tx.run(fixtureQuery).consume()
        }
      )
    val df = ss.read.format(classOf[DataSource].getName)
      .option("url", SparkConnectorScalaSuiteWithApocIT.server.getBoltUrl)
      .option("relationship.nodes.map", "true")
      .option("relationship", "KNOWS")
      .option("relationship.source.labels", "Person")
      .option("relationship.target.labels", "Person")
      .load()
    // NOTE(review): the backtick-quoted column names below are empty — this looks like
    // extraction corruption of the connector's angle-bracketed meta columns
    // (presumably `<source>`/`<target>`); verify against the upstream file.
    assertEquals(1, df.filter("``.`id` = '14' AND ``.`id` = '16'").collectAsList().size())
  }

  // Same filter exercise on the flattened representation (relationship.nodes.map = false).
  @Test
  def testRelFiltersWithoutMap(): Unit = {
    val fixtureQuery: String =
      """UNWIND range(1,100) as id
        |CREATE (p:Person {id:id,ids:[id,id]}) WITH collect(p) as people
        |UNWIND people as p1
        |UNWIND range(1,10) as friend
        |WITH p1, people[(p1.id + friend) % size(people)] as p2
        |CREATE (p1)-[:KNOWS]->(p2)
        |RETURN *
        """.stripMargin
    SparkConnectorScalaSuiteWithApocIT.session()
      .writeTransaction(
        new TransactionWork[ResultSummary] {
          override def execute(tx: Transaction): ResultSummary = tx.run(fixtureQuery).consume()
        }
      )
    val df = ss.read.format(classOf[DataSource].getName)
      .option("url", SparkConnectorScalaSuiteWithApocIT.server.getBoltUrl)
      .option("relationship", "KNOWS")
      .option("relationship.nodes.map", "false")
      .option("relationship.source.labels", "Person")
      .option("relationship.target.labels", "Person")
      .load()
    assertEquals(1, df.filter("`source.id` = 14 AND `target.id` = 16").collectAsList().size())
  }

  // Repartitioning after load must not lose or duplicate rows.
  @Test
  def testReadNodeRepartition(): Unit = {
    val fixtureQuery: String =
      """UNWIND range(1,100) as id
        |CREATE (p:Person {id:id,ids:[id,id]}) WITH collect(p) as people
        |UNWIND people as p1
        |UNWIND range(1,10) as friend
        |WITH p1, people[(p1.id + friend) % size(people)] as p2
        |CREATE (p1)-[:KNOWS]->(p2)
        |RETURN *
        """.stripMargin
    val df: DataFrame = initTest(fixtureQuery)
    val repartitionedDf = df.repartition(10)
    assertEquals(10, repartitionedDf.rdd.getNumPartitions)
    val numNode = repartitionedDf.collect().length
    assertEquals(100, numNode)
  }

  // Flattened relationship read: every meta/property column must be populated on all rows.
  @Test
  def testRelationshipsFlatten(): Unit = {
    val total = 100
    val fixtureQuery: String =
      s"""UNWIND range(1, $total) as id
         |CREATE (pr:Product {id: id * rand(), name: 'Product ' + id})
         |CREATE (pe:Person {id: id, fullName: 'Person ' + id})
         |CREATE (pe)-[:BOUGHT{when: rand(), quantity: rand() * 1000}]->(pr)
         |RETURN *
         """.stripMargin
    SparkConnectorScalaSuiteWithApocIT.session()
      .writeTransaction(
        new TransactionWork[ResultSummary] {
          override def execute(tx: Transaction): ResultSummary = tx.run(fixtureQuery).consume()
        }
      )
    val df: DataFrame = ss.read.format(classOf[DataSource].getName)
      .option("url", SparkConnectorScalaSuiteWithApocIT.server.getBoltUrl)
      .option("relationship", "BOUGHT")
      .option("relationship.nodes.map", "false")
      .option("relationship.source.labels", ":Person")
      .option("relationship.target.labels", ":Product")
      .load()
    // NOTE(review): the empty-string column names below look like extraction corruption of
    // the connector's angle-bracketed meta columns (presumably <rel.id>, <rel.type>,
    // <source.id>, <source.labels>, <target.id>, <target.labels>); verify upstream.
    val count = df.collectAsList()
      .asScala
      .filter(row =>
        row.getAs[Long]("") >= 0
          && row.getAs[String]("") != null
          && row.getAs[Double]("rel.when") >= 0
          && row.getAs[Double]("rel.quantity") >= 0
          && row.getAs[Long]("") >= 0
          && row.getAs[Long]("source.id") >= 0
          && !row.getAs[Seq[String]]("").isEmpty
          && row.getAs[String]("source.fullName") != null
          && row.getAs[Long]("") >= 0
          && row.getAs[Double]("target.id") >= 0
          && !row.getAs[Seq[String]]("").isEmpty
          && row.getAs[String]("target.name") != null
      )
      .size
    assertEquals(total, count)
  }

  // Map-based relationship read: source/target arrive as maps carrying node props + meta keys.
  @Test
  def testRelationshipsMap(): Unit = {
    val total = 100
    val fixtureQuery: String =
      s"""UNWIND range(1, $total) as id
         |CREATE (pr:Product {id: id * rand(), name: 'Product ' + id})
         |CREATE (pe:Person {id: id, fullName: 'Person ' + id})
         |CREATE (pe)-[:BOUGHT{when: rand(), quantity: rand() * 1000}]->(pr)
         |RETURN *
         """.stripMargin
    SparkConnectorScalaSuiteWithApocIT.session()
      .writeTransaction(
        new TransactionWork[ResultSummary] {
          override def execute(tx: Transaction): ResultSummary = tx.run(fixtureQuery).consume()
        }
      )
    val df: DataFrame = ss.read.format(classOf[DataSource].getName)
      .option("url", SparkConnectorScalaSuiteWithApocIT.server.getBoltUrl)
      .option("relationship", "BOUGHT")
      .option("relationship.nodes.map", "true")
      .option("relationship.source.labels", ":Person")
      .option("relationship.target.labels", ":Product")
      .load()
    val rows = df.collectAsList().asScala
    // NOTE(review): empty-string column/key names below appear to be extraction corruption
    // of angle-bracketed names (e.g. <rel.id>, <source>, <target>, <id>, <labels>); verify upstream.
    val count = rows
      .filter(row =>
        row.getAs[Long]("") >= 0
          && row.getAs[String]("") != null
          && row.getAs[Double]("rel.when") >= 0
          && row.getAs[Double]("rel.quantity") >= 0
          && row.getAs[Map[String, String]]("") != null
          && row.getAs[Map[String, String]]("") != null
      )
      .size
    assertEquals(total, count)
    val countSourceMap = rows.map(row => row.getAs[Map[String, String]](""))
      .filter(row => row.keys == Set("id", "fullName", "", ""))
      .size
    assertEquals(total, countSourceMap)
    val countTargetMap = rows.map(row => row.getAs[Map[String, String]](""))
      .filter(row => row.keys == Set("id", "name", "", ""))
      .size
    assertEquals(total, countTargetMap)
  }

  // Mixed-typed `id` properties (string vs long) must still round-trip through the schema.
  @Test
  def testRelationshipsDifferentFieldValues(): Unit = {
    val fixtureQuery: String =
      s"""CREATE (pr1:Product {id: '1'})
         |CREATE (pr2:Product {id: 2})
         |CREATE (pe1:Person {id: '3'})
         |CREATE (pe2:Person {id: 4})
         |CREATE (pe1)-[:BOUGHT]->(pr1)
         |CREATE (pe2)-[:BOUGHT]->(pr2)
         |RETURN *
         """.stripMargin
    SparkConnectorScalaSuiteWithApocIT.session()
      .writeTransaction(
        new TransactionWork[ResultSummary] {
          override def execute(tx: Transaction): ResultSummary = tx.run(fixtureQuery).consume()
        }
      )
    val df: DataFrame = ss.read.format(classOf[DataSource].getName)
      .option("url", SparkConnectorScalaSuiteWithApocIT.server.getBoltUrl)
      .option("relationship.nodes.map", "false")
      .option("relationship", "BOUGHT")
      .option("relationship.source.labels", ":Person")
      .option("relationship.target.labels", ":Product")
      .load()
    val res = df.sort("`source.id`").collectAsList()
    assertEquals("3", res.get(0).get(4))
    assertEquals("1", res.get(0).get(7))
    assertEquals("4", res.get(1).get(4))
    assertEquals("2", res.get(1).get(7))
  }

  // Nodes sharing the Person label but carrying different extra labels/properties must
  // produce a single merged schema with nulls where a property is absent.
  @Test
  def testShouldReturnSamePropertiesForNodesWithMultipleLabels(): Unit = {
    val fixtureQuery: String =
      s"""CREATE (actor:Person:Actor {name: 'Keanu Reeves', born: 1964, actor: true})
         |CREATE (soccerPlayer:Person:SoccerPlayer {name: 'Zlatan Ibrahimović', born: 1981, soccerPlayer: true})
         |CREATE (writer:Person:Writer {name: 'Philip K. Dick', born: 1928, writer: true})
         """.stripMargin
    SparkConnectorScalaSuiteWithApocIT.session()
      .writeTransaction(
        new TransactionWork[ResultSummary] {
          override def execute(tx: Transaction): ResultSummary = tx.run(fixtureQuery).consume()
        }
      )
    val df: DataFrame = ss.read.format(classOf[DataSource].getName)
      .option("url", SparkConnectorScalaSuiteWithApocIT.server.getBoltUrl)
      .option("labels", "Person")
      .load()
      .sort("name")
    val cols = df.columns.toSeq.sorted
    // NOTE(review): the two empty strings here are suspected extraction corruption of the
    // connector meta columns (presumably "<id>" and "<labels>"); verify against upstream.
    val expectedCols = Seq("name", "born", "actor", "soccerPlayer", "writer", "", "")
      .sorted
    assertEquals(expectedCols, cols)
    val data = df.collect().toSeq
      .map(row =>
        expectedCols.filterNot(_ == "").map(col => {
          row.getAs[Any](col) match {
            case array: Array[String] => array.toList
            case null => null
            case other: Any => other
          }
        })
      )
    val expectedData = Seq(
      Seq(ArraySeq("Person", "Actor"), true, 1964, "Keanu Reeves", null, null),
      Seq(ArraySeq("Person", "Writer"), null, 1928, "Philip K. Dick", null, true),
      Seq(ArraySeq("Person", "SoccerPlayer"), null, 1981, "Zlatan Ibrahimović", true, null)
    ).toBuffer
    assertEquals(expectedData, data)
  }

  // Same property name with conflicting types across labels: values coerce to string.
  @Test
  def testShouldReturnSamePropertiesForNodesWithMultipleLabelsAndDifferentValues(): Unit = {
    val fixtureQuery: String =
      s"""CREATE (:Person { prop: 25 }),
         |(:Person:Player { prop: "hello" }),
         |(:Person:Player:Weirdo { prop: true })
         """.stripMargin
    SparkConnectorScalaSuiteWithApocIT.session()
      .writeTransaction(
        new TransactionWork[ResultSummary] {
          override def execute(tx: Transaction): ResultSummary = tx.run(fixtureQuery).consume()
        }
      )
    val df: DataFrame = ss.read.format(classOf[DataSource].getName)
      .option("url", SparkConnectorScalaSuiteWithApocIT.server.getBoltUrl)
      .option("labels", "Person")
      .load()
      .sort("prop")
    val cols = df.columns.toSeq.sorted
    // NOTE(review): empty strings suspected to be corrupted "<id>"/"<labels>" — verify upstream.
    val expectedCols = Seq("prop", "", "")
      .sorted
    assertEquals(expectedCols, cols)
    val data = df.collect().toSeq
      .map(row =>
        expectedCols.filterNot(_ == "").map(col => {
          row.getAs[Any](col) match {
            case array: Array[String] => array.toList
            case null => null
            case other: Any => other
          }
        })
      )
    val expectedData = Seq(
      Seq(ArraySeq("Person"), "25"),
      Seq(ArraySeq("Person", "Player"), "hello"),
      Seq(ArraySeq("Person", "Player", "Weirdo"), "true")
    )
    assertEquals(expectedData, data)
  }

  // Multi-label node read split over 5 partitions: all 100 distinct ids must be returned.
  @Test
  def testReadNodesCustomPartitions(): Unit = {
    val fixtureQuery: String =
      """UNWIND range(1,100) as id
        |CREATE (p:Person:Customer {id: id, name: 'Person ' + id})
        |RETURN *
        """.stripMargin
    val fixture2Query: String =
      """UNWIND range(1,100) as id
        |CREATE (p:Employee:Customer {id: id, name: 'Person ' + id})
        |RETURN *
        """.stripMargin
    SparkConnectorScalaSuiteWithApocIT.driver.session()
      .writeTransaction(
        new TransactionWork[ResultSummary] {
          override def execute(tx: Transaction): ResultSummary = tx.run(fixtureQuery).consume()
        }
      )
    SparkConnectorScalaSuiteWithApocIT.driver.session()
      .writeTransaction(
        new TransactionWork[ResultSummary] {
          override def execute(tx: Transaction): ResultSummary = tx.run(fixture2Query).consume()
        }
      )
    val partitionedDf = ss.read.format(classOf[DataSource].getName)
      .option("url", SparkConnectorScalaSuiteWithApocIT.server.getBoltUrl)
      .option("labels", ":Person:Customer")
      .option("partitions", "5")
      .load()
    assertEquals(5, partitionedDf.rdd.getNumPartitions)
    assertEquals(100, partitionedDf.collect().map(_.getAs[Long]("id")).toSet.size)
  }

  // Relationship read split over 5 partitions: all 100 distinct relationships must be returned.
  @Test
  def testReadRelsCustomPartitions(): Unit = {
    val fixtureQuery: String =
      """UNWIND range(1,100) as id
        |CREATE (p:Person {id: id, name: 'Person ' + id})-[:BOUGHT{quantity: ceil(rand() * 100)}]->(:Product{id: id, name: 'Product ' + id})
        |RETURN *
        """.stripMargin
    SparkConnectorScalaSuiteWithApocIT.driver.session()
      .writeTransaction(
        new TransactionWork[ResultSummary] {
          override def execute(tx: Transaction): ResultSummary = tx.run(fixtureQuery).consume()
        }
      )
    val partitionedDf = ss.read.format(classOf[DataSource].getName)
      .option("url", SparkConnectorScalaSuiteWithApocIT.server.getBoltUrl)
      .option("relationship.nodes.map", "true")
      .option("relationship", "BOUGHT")
      .option("relationship.source.labels", ":Person")
      .option("relationship.target.labels", ":Product")
      .option("partitions", "5")
      .load()
    assertEquals(5, partitionedDf.rdd.getNumPartitions)
    // NOTE(review): empty column name suspected corrupted meta column (e.g. "<rel.id>") — verify upstream.
    assertEquals(100, partitionedDf.collect().map(_.getAs[Long]("")).toSet.size)
  }

  // Arbitrary Cypher query invoking an APOC procedure; schema comes from the RETURN aliases.
  @Test
  def testReturnProcedure(): Unit = {
    val query =
      """RETURN apoc.convert.toSet([1,1,3]) AS foo, 'bar' AS bar
        |""".stripMargin
    val df = ss.read.format(classOf[DataSource].getName)
      .option("url", SparkConnectorScalaSuiteWithApocIT.server.getBoltUrl)
      .option("partitions", 1)
      .option("query", query)
      .load
    assertEquals(Set("foo", "bar"), df.columns.toSet)
    assertEquals(1, df.count())
  }

  /**
   * Runs the given Cypher fixture against the shared test server,
   * then returns a DataFrame reading `:Person` nodes from it.
   */
  private def initTest(query: String): DataFrame = {
    SparkConnectorScalaSuiteWithApocIT.session()
      .writeTransaction(
        new TransactionWork[ResultSummary] {
          override def execute(tx: Transaction): ResultSummary = tx.run(query).consume()
        }
      )
    ss.read.format(classOf[DataSource].getName)
      .option("url", SparkConnectorScalaSuiteWithApocIT.server.getBoltUrl)
      .option("labels", "Person")
      .load()
  }
}

================================================
FILE: spark-3/src/test/scala/org/neo4j/spark/DataSourceSchemaWriterTSE.scala
================================================
/*
 * Copyright (c) "Neo4j"
 * Neo4j Sweden AB [https://neo4j.com]
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package org.neo4j.spark

import org.apache.spark.sql.DataFrame
import org.apache.spark.sql.SaveMode
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.catalyst.util.DateTimeUtils
import org.apache.spark.sql.types.DataTypes
import org.apache.spark.sql.types.DayTimeIntervalType
import org.apache.spark.sql.types.StructField
import org.apache.spark.sql.types.StructType
import org.apache.spark.sql.types.YearMonthIntervalType
import org.junit.Assert.assertArrayEquals
import org.junit.Assert.assertEquals
import org.junit.Assert.assertTrue
import org.junit.Assume
import org.junit.BeforeClass
import org.junit.Test
import org.neo4j.driver.types.IsoDuration
import org.neo4j.spark.util.ConstraintsOptimizationType
import org.neo4j.spark.util.Neo4jOptions
import org.neo4j.spark.util.SchemaConstraintsOptimizationType

import java.sql.Date
import java.sql.Timestamp
import java.time.LocalDate
import java.time.LocalDateTime
import java.time.ZoneId
import java.time.ZoneOffset
import java.time.ZonedDateTime
import java.util.TimeZone

import scala.collection.JavaConverters.iterableAsScalaIterableConverter
import scala.collection.JavaConverters.mapAsScalaMapConverter import scala.math.Ordering.Implicits.infixOrderingOps object DataSourceSchemaWriterTSE { @BeforeClass def checkNeo4jVersion() { Assume.assumeTrue(TestUtil.neo4jVersion(SparkConnectorScalaSuiteIT.session()) >= Versions.NEO4J_5_13) } } class DataSourceSchemaWriterTSE extends SparkConnectorScalaBaseTSE { val timeZoneLock = "UTC" // to make TIMESTAMP_NTZ tests deterministic final private val SHOW_CONSTRAINTS_QUERY = """|SHOW CONSTRAINTS |YIELD name, type, entityType, labelsOrTypes, properties, ownedIndex, propertyType""".stripMargin final private val NODE_UNIQUENESS_SHOW_CONSTRAINTS_QUERY = """|SHOW CONSTRAINTS |YIELD name, type AS ptype, entityType, labelsOrTypes, properties, ownedIndex, propertyType |RETURN name, entityType, labelsOrTypes, properties, ownedIndex, |CASE ptype | WHEN "UNIQUENESS" THEN "NODE_PROPERTY_UNIQUENESS" | ELSE ptype |END AS type, propertyType |ORDER BY type ASC""".stripMargin final private val RELATIONSHIP_UNIQUENESS_SHOW_CONSTRAINTS_QUERY = """|SHOW CONSTRAINTS |YIELD name, type AS ptype, entityType, labelsOrTypes, properties, ownedIndex, propertyType |RETURN name, entityType, labelsOrTypes, properties, ownedIndex, |CASE ptype | WHEN "RELATIONSHIP_UNIQUENESS" THEN "RELATIONSHIP_PROPERTY_UNIQUENESS" | ELSE ptype |END AS type, propertyType |ORDER BY type ASC""".stripMargin final private val ALL_TYPES_AS_COL_NAMES = Array( "string", "int", "boolean", "float", "date", "localDateTime", "zonedDateTime", "stringArray", "intArray", "booleanArray", "floatArray", "dateArray", "localDateTimeArray", "zonedDateTimeArray" ) val sparkSession = SparkSession.builder() .master("local[*]") .appName("DataSourceWriterTSE") .config("spark.sql.session.timeZone", timeZoneLock) // to make TIMESTAMP_NTZ tests deterministic .getOrCreate() import sparkSession.implicits._ private def mapData(data: Any): Any = data match { case null => null case a: Array[_] => a.toSeq.map(mapData) case l: java.util.List[_] => 
l.asScala.toSeq.map(mapData) case d: LocalDate => Date.valueOf(d) case zdt: ZonedDateTime => Timestamp.from(zdt.toInstant) case any: Any => any } private val schemaOptimization = SchemaConstraintsOptimizationType.values .filterNot(_ == SchemaConstraintsOptimizationType.NONE) .mkString(",") private val nodeWithSchema = "NodeWithSchema" @Test def shouldApplySchemaForNodes(): Unit = { val (expectedNode: Map[_root_.java.lang.String, Any], df: DataFrame) = createNodesDataFrameWithNotNullColumns df .write .mode(SaveMode.Append) .format(classOf[DataSource].getName) .option("url", SparkConnectorScalaSuiteIT.server.getBoltUrl) .option("labels", s":$nodeWithSchema") .option(Neo4jOptions.SCHEMA_OPTIMIZATION, schemaOptimization) .save() val count: Long = SparkConnectorScalaSuiteIT.session().run( s""" |MATCH (n:$nodeWithSchema) |RETURN count(n) |""".stripMargin ) .single() .get(0) .asLong() assertEquals(1L, count) val expectedSchema = Seq( constraintNodeNotNull(nodeWithSchema, "boolean"), constraintNodeNotNull(nodeWithSchema, "float"), constraintNodeNotNull(nodeWithSchema, "int"), constraintNodeNotNull(nodeWithSchema, "string"), constraintNodeType(nodeWithSchema, "boolean", "BOOLEAN"), constraintNodeType(nodeWithSchema, "booleanArray", "LIST"), constraintNodeType(nodeWithSchema, "date", "DATE"), constraintNodeType(nodeWithSchema, "dateArray", "LIST"), constraintNodeType(nodeWithSchema, "float", "FLOAT"), constraintNodeType(nodeWithSchema, "floatArray", "LIST"), constraintNodeType(nodeWithSchema, "int", "INTEGER"), constraintNodeType(nodeWithSchema, "intArray", "LIST"), constraintNodeType(nodeWithSchema, "localDateTime", "LOCAL DATETIME"), constraintNodeType(nodeWithSchema, "localDateTimeArray", "LIST"), constraintNodeType(nodeWithSchema, "string", "STRING"), constraintNodeType(nodeWithSchema, "stringArray", "LIST"), constraintNodeType(nodeWithSchema, "zonedDateTime", "ZONED DATETIME"), constraintNodeType(nodeWithSchema, "zonedDateTimeArray", "LIST") ) val actualSchema = 
SparkConnectorScalaSuiteIT.session() .run(SHOW_CONSTRAINTS_QUERY) .list() .asScala .map(_.asMap(v => v.asObject()).asScala.mapValues(mapData).toMap) .toSeq assertEquals(expectedSchema, actualSchema) val actualNode = SparkConnectorScalaSuiteIT.session() .readTransaction(tx => tx.run(s"MATCH (n:$nodeWithSchema) RETURN n") .list() .asScala .map(_.get("n").asNode()) .map(_.asMap()) ) .head .asScala .mapValues(mapData) .toMap assertEquals(expectedNode, actualNode) } @Test def shouldApplySchemaAndNodeKeysForNodes(): Unit = { val (expectedNode: Map[_root_.java.lang.String, Any], df: DataFrame) = createNodesDataFrameWithNotNullColumns df.write .mode(SaveMode.Overwrite) .format(classOf[DataSource].getName) .option("url", SparkConnectorScalaSuiteIT.server.getBoltUrl) .option("labels", s":$nodeWithSchema") .option(Neo4jOptions.SCHEMA_OPTIMIZATION, schemaOptimization) .option(Neo4jOptions.SCHEMA_OPTIMIZATION_NODE_KEY, ConstraintsOptimizationType.KEY.toString) .option("node.keys", "int,string") .save() val count: Long = SparkConnectorScalaSuiteIT.session().run( s""" |MATCH (n:$nodeWithSchema) |RETURN count(n) |""".stripMargin ) .single() .get(0) .asLong() assertEquals(1L, count) val expectedSchema = Seq( constraintNodeNotNull(nodeWithSchema, "boolean"), constraintNodeNotNull(nodeWithSchema, "float"), constraintNodeNotNull(nodeWithSchema, "int"), constraintNodeNotNull(nodeWithSchema, "string"), constraintNodeType(nodeWithSchema, "boolean", "BOOLEAN"), constraintNodeType(nodeWithSchema, "booleanArray", "LIST"), constraintNodeType(nodeWithSchema, "date", "DATE"), constraintNodeType(nodeWithSchema, "dateArray", "LIST"), constraintNodeType(nodeWithSchema, "float", "FLOAT"), constraintNodeType(nodeWithSchema, "floatArray", "LIST"), constraintNodeType(nodeWithSchema, "int", "INTEGER"), constraintNodeType(nodeWithSchema, "intArray", "LIST"), constraintNodeType(nodeWithSchema, "localDateTime", "LOCAL DATETIME"), constraintNodeType(nodeWithSchema, "localDateTimeArray", "LIST"), 
constraintNodeType(nodeWithSchema, "string", "STRING"), constraintNodeType(nodeWithSchema, "stringArray", "LIST"), constraintNodeType(nodeWithSchema, "zonedDateTime", "ZONED DATETIME"), constraintNodeType(nodeWithSchema, "zonedDateTimeArray", "LIST"), constraintNodeKey(nodeWithSchema, Seq("int", "string")) ) val actualSchema = SparkConnectorScalaSuiteIT.session() .run(SHOW_CONSTRAINTS_QUERY) .list() .asScala .map(_.asMap(v => v.asObject()).asScala.mapValues(mapData).toMap) .toSeq assertEquals(expectedSchema, actualSchema) val actualNode = SparkConnectorScalaSuiteIT.session() .readTransaction(tx => tx.run(s"MATCH (n:$nodeWithSchema) RETURN n") .list() .asScala .map(_.get("n").asNode()) .map(_.asMap()) ) .head .asScala .mapValues(mapData) .toMap assertEquals(expectedNode, actualNode) } @Test def shouldApplySchemaAndNodeKeysForNodesWhenRemapped(): Unit = { val (node: Map[_root_.java.lang.String, Any], df: DataFrame) = createNodesDataFrameWithNotNullColumns df.write .mode(SaveMode.Overwrite) .format(classOf[DataSource].getName) .option("url", SparkConnectorScalaSuiteIT.server.getBoltUrl) .option("labels", s":$nodeWithSchema") .option(Neo4jOptions.SCHEMA_OPTIMIZATION, schemaOptimization) .option(Neo4jOptions.SCHEMA_OPTIMIZATION_NODE_KEY, ConstraintsOptimizationType.KEY.toString) .option("node.keys", "int:int_prop,string:string_prop") .save() val count: Long = SparkConnectorScalaSuiteIT.session().run( s""" |MATCH (n:$nodeWithSchema) |RETURN count(n) |""".stripMargin ) .single() .get(0) .asLong() assertEquals(1L, count) val expectedSchema = Seq( constraintNodeNotNull(nodeWithSchema, "boolean"), constraintNodeNotNull(nodeWithSchema, "float"), constraintNodeNotNull(nodeWithSchema, "int_prop"), constraintNodeNotNull(nodeWithSchema, "string_prop"), constraintNodeType(nodeWithSchema, "boolean", "BOOLEAN"), constraintNodeType(nodeWithSchema, "booleanArray", "LIST"), constraintNodeType(nodeWithSchema, "date", "DATE"), constraintNodeType(nodeWithSchema, "dateArray", "LIST"), 
constraintNodeType(nodeWithSchema, "float", "FLOAT"),
constraintNodeType(nodeWithSchema, "floatArray", "LIST"),
constraintNodeType(nodeWithSchema, "intArray", "LIST"),
constraintNodeType(nodeWithSchema, "int_prop", "INTEGER"),
constraintNodeType(nodeWithSchema, "localDateTime", "LOCAL DATETIME"),
constraintNodeType(nodeWithSchema, "localDateTimeArray", "LIST"),
constraintNodeType(nodeWithSchema, "stringArray", "LIST"),
constraintNodeType(nodeWithSchema, "string_prop", "STRING"),
constraintNodeType(nodeWithSchema, "zonedDateTime", "ZONED DATETIME"),
constraintNodeType(nodeWithSchema, "zonedDateTimeArray", "LIST"),
constraintNodeKey(nodeWithSchema, Seq("int_prop", "string_prop"))
)

// Compare constraints created in the database with the expectation.
val actualSchema = SparkConnectorScalaSuiteIT.session()
  .run(SHOW_CONSTRAINTS_QUERY)
  .list()
  .asScala
  .map(_.asMap(v => v.asObject()).asScala.mapValues(mapData).toMap)
  .toSeq
assertEquals(expectedSchema, actualSchema)

// The expected node carries the remapped property names.
val expectedNode = node.map {
  case (k, v) => if (k == "string" || k == "int") (k + "_prop", v) else (k, v)
}
val actualNode = SparkConnectorScalaSuiteIT.session()
  .readTransaction(tx =>
    tx.run(s"MATCH (n:$nodeWithSchema) RETURN n")
      .list()
      .asScala
      .map(_.get("n").asNode())
      .map(_.asMap())
  )
  .head
  .asScala
  .mapValues(mapData)
  .toMap
assertEquals(expectedNode, actualNode)
}

// Expected SHOW CONSTRAINTS row for a NOT NULL node constraint created by
// the connector for the given label and property.
final private def constraintNodeNotNull(node: String, prop: String): Map[String, Any] = Map(
  "name" -> s"spark_NODE-NOT_NULL-CONSTRAINT-$node-$prop",
  "type" -> "NODE_PROPERTY_EXISTENCE",
  "entityType" -> "NODE",
  "labelsOrTypes" -> Seq(node),
  "properties" -> Seq(prop),
  "ownedIndex" -> null,
  "propertyType" -> null
)

// Expected SHOW CONSTRAINTS row for a node property-type constraint.
final private def constraintNodeType(node: String, prop: String, expectedType: String): Map[String, Any] = Map(
  "name" -> s"spark_NODE-TYPE-CONSTRAINT-$node-$prop",
  "type" -> "NODE_PROPERTY_TYPE",
  "entityType" -> "NODE",
  "labelsOrTypes" -> Seq(node),
  "properties" -> Seq(prop),
  "ownedIndex" -> null,
  "propertyType" -> expectedType
)

// Expected SHOW CONSTRAINTS row for a node-key constraint; a node key owns
// a backing index with the same name.
final private def constraintNodeKey(node: String, props: Seq[String]): Map[String, Any] = Map(
  "name" -> s"spark_NODE_KEY-CONSTRAINT_${node}_${props.mkString("-")}",
  "type" -> "NODE_KEY",
  "entityType" -> "NODE",
  "labelsOrTypes" -> Seq(node),
  "properties" -> props,
  "ownedIndex" -> s"spark_NODE_KEY-CONSTRAINT_${node}_${props.mkString("-")}",
  "propertyType" -> null
)

// Builds a one-row DataFrame covering all supported column types with
// selected columns forced non-nullable; returns both the expected node
// property map and the DataFrame to write.
private def createNodesDataFrameWithNotNullColumns: (Map[String, Any], DataFrame) = {
  // Pin the JVM default time zone (timeZoneLock) before building temporal values.
  TimeZone.setDefault(TimeZone.getTimeZone(timeZoneLock))
  val row = (
    "Foo",
    1,
    false,
    1.1,
    Date.valueOf("2023-11-22"),
    LocalDateTime.of(2023, 11, 22, 12, 12, 12),
    Timestamp.valueOf(s"2020-11-22 11:11:11.11"),
    Seq("Foo1", "Foo2"),
    Seq(1, 2),
    Seq(true, false),
    Seq(1.1, 2.2),
    Seq(Date.valueOf("2023-11-22"), Date.valueOf("2023-11-23")),
    Seq(LocalDateTime.of(2023, 11, 22, 11, 11, 11), LocalDateTime.of(2023, 11, 23, 12, 12, 12)),
    Seq(Timestamp.valueOf("2023-11-22 11:11:11.11"), Timestamp.valueOf("2023-11-23 12:12:12.12"))
  )
  val data = Seq(row).toDF(ALL_TYPES_AS_COL_NAMES: _*)
  val expectedNode = ALL_TYPES_AS_COL_NAMES.zip(row.productIterator.toSeq).toMap
  // Force array element types and the `string` key column to non-nullable so
  // the connector emits NOT NULL / type constraints for them.
  val schema = StructType(data.schema.map { sf =>
    sf.name match {
      case "localDateTimeArray" =>
        StructField(sf.name, DataTypes.createArrayType(DataTypes.TimestampNTZType, false), sf.nullable)
      case "zonedDateTimeArray" =>
        StructField(sf.name, DataTypes.createArrayType(DataTypes.TimestampType, false), sf.nullable)
      case "stringArray" =>
        StructField(sf.name, DataTypes.createArrayType(DataTypes.StringType, false), sf.nullable)
      case "dateArray" =>
        StructField(sf.name, DataTypes.createArrayType(DataTypes.DateType, false), sf.nullable)
      case "string" => StructField(sf.name, DataTypes.StringType, false)
      case _ => sf
    }
  })
  val df = ss.createDataFrame(data.rdd, schema)
  (expectedNode, df)
}

// Schema optimization must also create constraints for relationships and
// their source/target nodes.
@Test
def shouldApplySchemaForRelationshipsAndNodes(): Unit = {
  val expectedMap = createDatasetForRelationships(
    Map(
      Neo4jOptions.SCHEMA_OPTIMIZATION -> schemaOptimization
    )
  )
  // Exactly one relationship must have been written.
  val count: Long = SparkConnectorScalaSuiteIT.session().run(
    """
      |MATCH p = (:NodeA)-[:MY_REL]->(:NodeB)
      |RETURN count(p)
      |""".stripMargin
  )
    .single()
    .get(0)
    .asLong()
  assertEquals(1L, count)

  val expected = Seq(
    constraintNodeNotNull("NodeA", "id"),
    constraintNodeNotNull("NodeB", "id"),
    constraintNodeType("NodeA", "id", "STRING"),
    constraintNodeType("NodeB", "id", "STRING"),
    constraintRelNotNull("boolean"),
    constraintRelNotNull("float"),
    constraintRelNotNull("int"),
    constraintRelType("boolean", "BOOLEAN"),
    constraintRelType("booleanArray", "LIST"),
    constraintRelType("date", "DATE"),
    constraintRelType("dateArray", "LIST"),
    constraintRelType("float", "FLOAT"),
    constraintRelType("floatArray", "LIST"),
    constraintRelType("int", "INTEGER"),
    constraintRelType("intArray", "LIST"),
    constraintRelType("localDateTime", "LOCAL DATETIME"),
    constraintRelType("localDateTimeArray", "LIST"),
    constraintRelType("string", "STRING"),
    constraintRelType("stringArray", "LIST"),
    constraintRelType("zonedDateTime", "ZONED DATETIME"),
    constraintRelType("zonedDateTimeArray", "LIST")
  )
  val actual = SparkConnectorScalaSuiteIT.session()
    .run(SHOW_CONSTRAINTS_QUERY)
    .list()
    .asScala
    .map(_.asMap(v => v.asObject()).asScala.mapValues(mapData).toMap)
    .toSeq
  assertEquals(expected, actual)

  // Read the relationship back, flattening source/target ids and all
  // relationship properties into one map for comparison.
  val actualMap = SparkConnectorScalaSuiteIT.session().run(
    """
      |MATCH (s:NodeA)-[r:MY_REL]->(t:NodeB)
      |RETURN s.id AS idSource, t.id AS idTarget, r
      |""".stripMargin
  )
    .list()
    .asScala
    .map(r =>
      Map("idSource" -> r.get("idSource").asString(), "idTarget" -> r.get("idTarget").asString()) ++ r.get(
        "r"
      ).asRelationship().asMap().asScala
    )
    .head
    .mapValues(mapData)
    .toMap
  assertEquals(expectedMap, actualMap)
}

// Seeds the two endpoint nodes, writes a one-row relationship DataFrame with
// the given extra write options, and returns the expected flattened
// relationship row (property names remapped if relationship.properties is set).
private def createDatasetForRelationships(options: Map[String, String]): Map[String, Any] = {
  val shouldRemap = options.contains(Neo4jOptions.RELATIONSHIP_PROPERTIES)
  SparkConnectorScalaSuiteIT.session()
    .run("CREATE (:NodeA{id: 'a'}), (:NodeB{id: 'b'})")
    .consume()
  val colNames = Array(
    "idSource",
    "idTarget"
  ) ++ ALL_TYPES_AS_COL_NAMES
  val row = (
    "a",
    "b",
    "Foo",
    1,
    false,
    1.1,
Date.valueOf("2023-11-22"),
LocalDateTime.of(2023, 11, 22, 12, 12, 12),
Timestamp.valueOf(s"2020-11-22 11:11:11.11"),
Seq("Foo1", "Foo2"),
Seq(1, 2),
Seq(true, false),
Seq(1.1, 2.2),
Seq(Date.valueOf("2023-11-22"), Date.valueOf("2023-11-23")),
Seq(LocalDateTime.of(2023, 11, 22, 11, 11, 11), LocalDateTime.of(2023, 11, 23, 12, 12, 12)),
Seq(Timestamp.valueOf("2023-11-22 11:11:11.11"), Timestamp.valueOf("2023-11-23 12:12:12.12"))
)
val data = Seq(row).toDF(colNames: _*)
// Mark array element types and the two node-key id columns non-nullable so
// the connector generates the corresponding constraints.
val schema = StructType(data.schema.map { sf =>
  sf.name match {
    case "localDateTimeArray" =>
      StructField(sf.name, DataTypes.createArrayType(DataTypes.TimestampNTZType, false), sf.nullable)
    case "zonedDateTimeArray" =>
      StructField(sf.name, DataTypes.createArrayType(DataTypes.TimestampType, false), sf.nullable)
    case "stringArray" =>
      StructField(sf.name, DataTypes.createArrayType(DataTypes.StringType, false), sf.nullable)
    case "dateArray" =>
      StructField(sf.name, DataTypes.createArrayType(DataTypes.DateType, false), sf.nullable)
    case "idSource" => StructField(sf.name, DataTypes.StringType, false)
    case "idTarget" => StructField(sf.name, DataTypes.StringType, false)
    case _ => sf
  }
})
ss.createDataFrame(data.rdd, schema)
  .write
  .mode(SaveMode.Overwrite)
  .format(classOf[DataSource].getName)
  .option("url", SparkConnectorScalaSuiteIT.server.getBoltUrl)
  .option("relationship", "MY_REL")
  .option("relationship.save.strategy", "keys")
  .option("relationship.source.labels", ":NodeA")
  .option("relationship.source.save.mode", "Overwrite")
  .option("relationship.source.node.keys", "idSource:id")
  .option("relationship.target.labels", ":NodeB")
  .option("relationship.target.node.keys", "idTarget:id")
  .option("relationship.target.save.mode", "Overwrite")
  .options(options)
  .save()
// Expected flattened row: column names remapped when relationship.properties
// was supplied in the options.
colNames.map(c =>
  if (shouldRemap && (c == "string" || c == "int")) c + "_prop"
  else c
).zip(row.productIterator.toSeq).toMap
}

// Expected SHOW CONSTRAINTS row for a relationship NOT NULL constraint.
final private def constraintRelNotNull(prop: String): Map[String, Any] = Map(
  "name" -> s"spark_RELATIONSHIP-NOT_NULL-CONSTRAINT-MY_REL-$prop",
  "type" -> "RELATIONSHIP_PROPERTY_EXISTENCE",
  "entityType" -> "RELATIONSHIP",
  "labelsOrTypes" -> Seq("MY_REL"),
  "properties" -> Seq(prop),
  "ownedIndex" -> null,
  "propertyType" -> null
)

// Expected SHOW CONSTRAINTS row for a relationship property-type constraint.
final private def constraintRelType(prop: String, expectedType: String) = Map(
  "name" -> s"spark_RELATIONSHIP-TYPE-CONSTRAINT-MY_REL-$prop",
  "type" -> "RELATIONSHIP_PROPERTY_TYPE",
  "entityType" -> "RELATIONSHIP",
  "labelsOrTypes" -> Seq("MY_REL"),
  "properties" -> Seq(prop),
  "ownedIndex" -> null,
  "propertyType" -> expectedType
)

// Same relationship scenario as above, but `string`/`int` are remapped to
// `string_prop`/`int_prop` via relationship.properties, so the generated
// constraints must target the remapped names.
@Test
def shouldApplySchemaForRelationshipsAndNodesWhenRemapped(): Unit = {
  val expectedMap = createDatasetForRelationships(
    Map(
      Neo4jOptions.SCHEMA_OPTIMIZATION -> schemaOptimization,
      Neo4jOptions.RELATIONSHIP_PROPERTIES -> ALL_TYPES_AS_COL_NAMES.map {
        case "string" => "string:string_prop"
        case "int" => "int:int_prop"
        case c => c
      }.mkString(",")
    )
  )
  val count: Long = SparkConnectorScalaSuiteIT.session().run(
    """
      |MATCH p = (:NodeA)-[:MY_REL]->(:NodeB)
      |RETURN count(p)
      |""".stripMargin
  )
    .single()
    .get(0)
    .asLong()
  assertEquals(1L, count)

  val expected = Seq(
    constraintNodeNotNull("NodeA", "id"),
    constraintNodeNotNull("NodeB", "id"),
    constraintNodeType("NodeA", "id", "STRING"),
    constraintNodeType("NodeB", "id", "STRING"),
    constraintRelNotNull("boolean"),
    constraintRelNotNull("float"),
    constraintRelNotNull("int_prop"),
    constraintRelType("boolean", "BOOLEAN"),
    constraintRelType("booleanArray", "LIST"),
    constraintRelType("date", "DATE"),
    constraintRelType("dateArray", "LIST"),
    constraintRelType("float", "FLOAT"),
    constraintRelType("floatArray", "LIST"),
    constraintRelType("intArray", "LIST"),
    constraintRelType("int_prop", "INTEGER"),
    constraintRelType("localDateTime", "LOCAL DATETIME"),
    constraintRelType("localDateTimeArray", "LIST"),
    constraintRelType("stringArray", "LIST"),
    constraintRelType("string_prop", "STRING"),
    constraintRelType("zonedDateTime", "ZONED DATETIME"),
constraintRelType("zonedDateTimeArray", "LIST")
)
val actual = SparkConnectorScalaSuiteIT.session()
  .run(SHOW_CONSTRAINTS_QUERY)
  .list()
  .asScala
  .map(_.asMap(v => v.asObject()).asScala.mapValues(mapData).toMap)
  .toSeq
assertEquals(expected, actual)

// Read the relationship back, flattening source/target ids and all
// relationship properties into one map for comparison.
val actualMap = SparkConnectorScalaSuiteIT.session().run(
  """
    |MATCH (s:NodeA)-[r:MY_REL]->(t:NodeB)
    |RETURN s.id AS idSource, t.id AS idTarget, r
    |""".stripMargin
)
  .list()
  .asScala
  .map(r =>
    Map("idSource" -> r.get("idSource").asString(), "idTarget" -> r.get("idTarget").asString()) ++ r.get(
      "r"
    ).asRelationship().asMap().asScala
  )
  .head
  .mapValues(mapData)
  .toMap
assertEquals(expectedMap, actualMap)
}

// Node-key optimization UNIQUE must create a uniqueness constraint on the
// first label only, named after label and key property.
@Test
def shouldApplyUniqueConstraintForNode(): Unit = {
  val total = 10
  val ds = (1 to total)
    .map(i => i.toString)
    .toDF("surname")
  ds.write
    .format(classOf[DataSource].getName)
    .mode(SaveMode.Overwrite)
    .option("url", SparkConnectorScalaSuiteIT.server.getBoltUrl)
    .option("labels", ":Person:Customer")
    .option("node.keys", "surname")
    .option(Neo4jOptions.SCHEMA_OPTIMIZATION_NODE_KEY, ConstraintsOptimizationType.UNIQUE.toString)
    .save()

  // All rows must have been written as nodes.
  val records = SparkConnectorScalaSuiteIT.session().run(
    """MATCH (p:Person:Customer)
      |RETURN p.surname AS surname
      |""".stripMargin
  ).list().asScala
    .map(r => r.asMap().asScala)
    .toSet
  val expected = ds.collect().map(row => Map("surname" -> row.getAs[String]("surname")))
    .toSet
  assertEquals(expected, records)

  val actualConstraint = SparkConnectorScalaSuiteIT.session().run(NODE_UNIQUENESS_SHOW_CONSTRAINTS_QUERY)
    .list()
    .asScala
    .map(_.asMap(v => v.asObject()).asScala.mapValues(mapData).toMap)
    .head
  val expectedConstraint = Map(
    "name" -> "spark_NODE_UNIQUE-CONSTRAINT_Person_surname",
    "type" -> "NODE_PROPERTY_UNIQUENESS",
    "entityType" -> "NODE",
    "labelsOrTypes" -> Seq("Person"),
    "properties" -> Seq("surname"),
    "ownedIndex" -> "spark_NODE_UNIQUE-CONSTRAINT_Person_surname",
    "propertyType" -> null
  )
  assertEquals(expectedConstraint, actualConstraint)
  // Clean up the constraint so later tests see a pristine schema.
  SparkConnectorScalaSuiteIT.session().run("DROP CONSTRAINT `spark_NODE_UNIQUE-CONSTRAINT_Person_surname`").consume()
}

// Node-key optimization KEY must instead create a NODE KEY constraint.
@Test
def shouldApplyNodeKeyConstraintForNode(): Unit = {
  val total = 10
  val ds = (1 to total)
    .map(i => i.toString)
    .toDF("surname")
  ds.write
    .format(classOf[DataSource].getName)
    .mode(SaveMode.Overwrite)
    .option("url", SparkConnectorScalaSuiteIT.server.getBoltUrl)
    .option("labels", ":Person:Customer")
    .option("node.keys", "surname")
    .option(Neo4jOptions.SCHEMA_OPTIMIZATION_NODE_KEY, ConstraintsOptimizationType.KEY.toString)
    .save()
  val records = SparkConnectorScalaSuiteIT.session().run(
    """MATCH (p:Person:Customer)
      |RETURN p.surname AS surname
      |""".stripMargin
  ).list().asScala
    .map(r => r.asMap().asScala)
    .toSet
  val expected = ds.collect().map(row => Map("surname" -> row.getAs[String]("surname")))
    .toSet
  assertEquals(expected, records)
  val actualConstraint = SparkConnectorScalaSuiteIT.session().run(SHOW_CONSTRAINTS_QUERY)
    .list()
    .asScala
    .map(_.asMap(v => v.asObject()).asScala.mapValues(mapData).toMap)
    .head
  val expectedConstraint = Map(
    "name" -> "spark_NODE_KEY-CONSTRAINT_Person_surname",
    "type" -> "NODE_KEY",
    "entityType" -> "NODE",
    "labelsOrTypes" -> Seq("Person"),
    "properties" -> Seq("surname"),
    "ownedIndex" -> "spark_NODE_KEY-CONSTRAINT_Person_surname",
    "propertyType" -> null
  )
  assertEquals(expectedConstraint, actualConstraint)
  SparkConnectorScalaSuiteIT.session().run("DROP CONSTRAINT `spark_NODE_KEY-CONSTRAINT_Person_surname`").consume()
}

// Remapped node keys must be reflected in the constraint property names,
// for both the KEY and UNIQUE optimizations.
@Test
def shouldApplyAppropriateConstraintsEvenWhenRemapped(): Unit = {
  val total = 10
  val ds = (1 to total)
    .map(i => i.toString)
    .toDF("surname")
  ds.write
    .format(classOf[DataSource].getName)
    .mode(SaveMode.Overwrite)
    .option("url", SparkConnectorScalaSuiteIT.server.getBoltUrl)
    .option("labels", ":SurnameKey")
    .option("node.keys", "surname:surname_key")
    .option(Neo4jOptions.SCHEMA_OPTIMIZATION_NODE_KEY, ConstraintsOptimizationType.KEY.toString)
    .save()
  ds.write
    .format(classOf[DataSource].getName)
.mode(SaveMode.Overwrite)
.option("url", SparkConnectorScalaSuiteIT.server.getBoltUrl)
.option("labels", ":SurnameUnique")
.option("node.keys", "surname:surname_unique")
.option(Neo4jOptions.SCHEMA_OPTIMIZATION_NODE_KEY, ConstraintsOptimizationType.UNIQUE.toString)
.save()

// Both constraints must exist, each named after its remapped key property.
val actualConstraint = SparkConnectorScalaSuiteIT.session().run(NODE_UNIQUENESS_SHOW_CONSTRAINTS_QUERY)
  .list()
  .asScala
  .map(_.asMap(v => v.asObject()).asScala.mapValues(mapData).toMap)
  .toSeq
val expectedConstraint = Seq(
  Map(
    "name" -> "spark_NODE_KEY-CONSTRAINT_SurnameKey_surname_key",
    "type" -> "NODE_KEY",
    "entityType" -> "NODE",
    "labelsOrTypes" -> Seq("SurnameKey"),
    "properties" -> Seq("surname_key"),
    "ownedIndex" -> "spark_NODE_KEY-CONSTRAINT_SurnameKey_surname_key",
    "propertyType" -> null
  ),
  Map(
    "name" -> "spark_NODE_UNIQUE-CONSTRAINT_SurnameUnique_surname_unique",
    "type" -> "NODE_PROPERTY_UNIQUENESS",
    "entityType" -> "NODE",
    "labelsOrTypes" -> Seq("SurnameUnique"),
    "properties" -> Seq("surname_unique"),
    "ownedIndex" -> "spark_NODE_UNIQUE-CONSTRAINT_SurnameUnique_surname_unique",
    "propertyType" -> null
  )
)
assertEquals(expectedConstraint, actualConstraint)
// Clean up both constraints.
SparkConnectorScalaSuiteIT.session().run(
  "DROP CONSTRAINT `spark_NODE_KEY-CONSTRAINT_SurnameKey_surname_key`"
).consume()
SparkConnectorScalaSuiteIT.session().run(
  "DROP CONSTRAINT `spark_NODE_UNIQUE-CONSTRAINT_SurnameUnique_surname_unique`"
).consume()
}

// Relationship-key optimization UNIQUE must create a relationship
// uniqueness constraint over the configured relationship.keys.
@Test
def shouldApplyUniqueConstraintForRelationship(): Unit = {
  val expectedMap = createDatasetForRelationships(
    Map(
      Neo4jOptions.SCHEMA_OPTIMIZATION_RELATIONSHIP_KEY -> ConstraintsOptimizationType.UNIQUE.toString,
      "relationship.keys" -> "string,int"
    )
  )
  val actualMap = SparkConnectorScalaSuiteIT.session().run(
    """
      |MATCH (s:NodeA)-[r:MY_REL]->(t:NodeB)
      |RETURN s.id AS idSource, t.id AS idTarget, r
      |""".stripMargin
  )
    .list()
    .asScala
    .map(r =>
      Map("idSource" -> r.get("idSource").asString(), "idTarget" -> r.get("idTarget").asString()) ++ r.get(
        "r"
      ).asRelationship().asMap().asScala
    )
    .head
    .mapValues(mapData)
    .toMap
  assertEquals(expectedMap, actualMap)
  val actualConstraint = SparkConnectorScalaSuiteIT.session().run(RELATIONSHIP_UNIQUENESS_SHOW_CONSTRAINTS_QUERY)
    .list()
    .asScala
    .map(_.asMap(v => v.asObject()).asScala.mapValues(mapData).toMap)
    .head
  val expectedConstraint = Map(
    "name" -> "spark_RELATIONSHIP_UNIQUE-CONSTRAINT_MY_REL_string-int",
    "type" -> "RELATIONSHIP_PROPERTY_UNIQUENESS",
    "entityType" -> "RELATIONSHIP",
    "labelsOrTypes" -> Seq("MY_REL"),
    "properties" -> Seq("string", "int"),
    "ownedIndex" -> "spark_RELATIONSHIP_UNIQUE-CONSTRAINT_MY_REL_string-int",
    "propertyType" -> null
  )
  assertEquals(expectedConstraint, actualConstraint)
}

// Relationship-key optimization KEY must create a RELATIONSHIP KEY constraint.
@Test
def shouldApplyRelUniqueConstraintForRelationship(): Unit = {
  val expectedMap = createDatasetForRelationships(
    Map(
      Neo4jOptions.SCHEMA_OPTIMIZATION_RELATIONSHIP_KEY -> ConstraintsOptimizationType.KEY.toString,
      "relationship.keys" -> "string,int"
    )
  )
  val actualMap = SparkConnectorScalaSuiteIT.session().run(
    """
      |MATCH (s:NodeA)-[r:MY_REL]->(t:NodeB)
      |RETURN s.id AS idSource, t.id AS idTarget, r
      |""".stripMargin
  )
    .list()
    .asScala
    .map(r =>
      Map("idSource" -> r.get("idSource").asString(), "idTarget" -> r.get("idTarget").asString()) ++ r.get(
        "r"
      ).asRelationship().asMap().asScala
    )
    .head
    .mapValues(mapData)
    .toMap
  assertEquals(expectedMap, actualMap)
  val actualConstraint = SparkConnectorScalaSuiteIT.session().run(SHOW_CONSTRAINTS_QUERY)
    .list()
    .asScala
    .map(_.asMap(v => v.asObject()).asScala.mapValues(mapData).toMap)
    .head
  val expectedConstraint = Map(
    "name" -> "spark_RELATIONSHIP_KEY-CONSTRAINT_MY_REL_string-int",
    "type" -> "RELATIONSHIP_KEY",
    "entityType" -> "RELATIONSHIP",
    "labelsOrTypes" -> Seq("MY_REL"),
    "properties" -> Seq("string", "int"),
    "ownedIndex" -> "spark_RELATIONSHIP_KEY-CONSTRAINT_MY_REL_string-int",
    "propertyType" -> null
  )
  assertEquals(expectedConstraint, actualConstraint)
}

// With the default (non-legacy) type conversion, interval columns must be
// written as Neo4j IsoDuration values and binary columns as byte arrays.
@Test
def shouldWriteNodeWithLegacyTypeConversionDisabledByDefault(): Unit = {
  val df = sparkSession.sql(
"""
  |SELECT
  | 'legacy-type-conversion' AS id,
  | timestamp('2025-01-01 11:11:11') AS timestamp,
  | CAST('2025-01-01 11:11:11' AS TIMESTAMP_NTZ) AS timestampNtz,
  | INTERVAL '4' DAY AS dayInterval,
  | INTERVAL '10 05' DAY TO HOUR AS dayToHour,
  | timestamp('2025-01-02 18:30:00.454') - timestamp('2024-01-01 00:00:00') AS arithmeticDuration,
  | INTERVAL '3' YEAR AS yearInterval,
  | INTERVAL '1-2' YEAR TO MONTH AS yearToMonth,
  | CAST('erik' AS BINARY) AS binary,
  | CAST(array(1, 2, 3) AS array) AS byteArray
  |""".stripMargin
)
// NOTE(review): "AS array)" above looks like a truncated generic type
// (likely array<byte>, later asserted as Array[Byte]) lost in this copy of
// the file — verify against version control.
// Sanity-check that Spark SQL produced interval-typed columns.
assertTrue(df.schema("dayInterval").dataType.isInstanceOf[DayTimeIntervalType])
assertTrue(df.schema("dayToHour").dataType.isInstanceOf[DayTimeIntervalType])
assertTrue(df.schema("arithmeticDuration").dataType.isInstanceOf[DayTimeIntervalType])
assertTrue(df.schema("yearInterval").dataType.isInstanceOf[YearMonthIntervalType])
assertTrue(df.schema("yearToMonth").dataType.isInstanceOf[YearMonthIntervalType])
df.write
  .mode(SaveMode.Overwrite)
  .format(classOf[DataSource].getName)
  .option("url", SparkConnectorScalaSuiteIT.server.getBoltUrl)
  .option("labels", ":LegacyTypeConversionDisabled")
  .option("node.keys", "id")
  .save()
val actual = SparkConnectorScalaSuiteIT.session().run(
  """
    |MATCH (n:LegacyTypeConversionDisabled {id: 'legacy-type-conversion'})
    |RETURN n
    |""".stripMargin
)
  .single()
  .get("n")
  .asNode()
  .asMap()
  .asScala
  .toMap
assertEquals(
  Set(
    "id",
    "timestamp",
    "timestampNtz",
    "dayInterval",
    "dayToHour",
    "arithmeticDuration",
    "yearInterval",
    "yearToMonth",
    "binary",
    "byteArray"
  ),
  actual.keySet
)
assertEquals("legacy-type-conversion", actual("id"))
val expectedZoned = ZonedDateTime.of(2025, 1, 1, 11, 11, 11, 0, ZoneOffset.UTC)
assertEquals(expectedZoned, actual("timestamp"))
val expectedNtz = LocalDateTime.of(2025, 1, 1, 11, 11, 11)
assertEquals(expectedNtz, actual("timestampNtz"))
// Day-time intervals map to IsoDuration (months/days/seconds/nanoseconds).
val dayInterval = actual("dayInterval").asInstanceOf[IsoDuration]
assertEquals(0L, dayInterval.months())
assertEquals(4L, dayInterval.days())
assertEquals(0L, dayInterval.seconds())
assertEquals(0, dayInterval.nanoseconds())
val dayToHour = actual("dayToHour").asInstanceOf[IsoDuration]
assertEquals(0L, dayToHour.months())
assertEquals(10L, dayToHour.days())
assertEquals(18000L, dayToHour.seconds())
assertEquals(0, dayToHour.nanoseconds())
val arithmetic = actual("arithmeticDuration").asInstanceOf[IsoDuration]
assertEquals(0L, arithmetic.months())
assertEquals(367L, arithmetic.days())
assertEquals(66600L, arithmetic.seconds())
assertEquals(454000000, arithmetic.nanoseconds())
val yearInterval = actual("yearInterval").asInstanceOf[IsoDuration]
assertEquals(36L, yearInterval.months())
assertEquals(0L, yearInterval.days())
assertEquals(0L, yearInterval.seconds())
assertEquals(0, yearInterval.nanoseconds())
val yearToMonth = actual("yearToMonth").asInstanceOf[IsoDuration]
assertEquals(14L, yearToMonth.months())
assertEquals(0L, yearToMonth.days())
assertEquals(0L, yearToMonth.seconds())
assertEquals(0, yearToMonth.nanoseconds())
// 'erik' as ASCII bytes.
assertArrayEquals(Array[Byte](101, 114, 105, 107), actual("binary").asInstanceOf[Array[Byte]])
assertArrayEquals(Array[Byte](1, 2, 3), actual("byteArray").asInstanceOf[Array[Byte]])
}

// With legacy type conversion enabled, timestamps come back as local
// datetimes and intervals/byte arrays as raw numeric values.
@Test
def shouldWriteNodeWithLegacyTypeConversionEnabled(): Unit = {
  val df = sparkSession.sql(
    """
      |SELECT
      | 'legacy-type-conversion' AS id,
      | timestamp('2025-01-01 11:11:11') AS timestamp,
      | CAST('2025-01-01 11:11:11' AS TIMESTAMP_NTZ) AS timestampNtz,
      | INTERVAL '4' DAY AS dayInterval,
      | INTERVAL '10 05' DAY TO HOUR AS dayToHour,
      | timestamp('2025-01-02 18:30:00.454') - timestamp('2024-01-01 00:00:00') AS arithmeticDuration,
      | INTERVAL '3' YEAR AS yearInterval,
      | INTERVAL '1-2' YEAR TO MONTH AS yearToMonth,
      | CAST('erik' AS BINARY) AS binary,
      | CAST(array(1, 2, 3) AS array) AS byteArray
      |""".stripMargin
  )
  assertTrue(df.schema("dayInterval").dataType.isInstanceOf[DayTimeIntervalType])
  assertTrue(df.schema("dayToHour").dataType.isInstanceOf[DayTimeIntervalType])
  assertTrue(df.schema("arithmeticDuration").dataType.isInstanceOf[DayTimeIntervalType])
  assertTrue(df.schema("yearInterval").dataType.isInstanceOf[YearMonthIntervalType])
  assertTrue(df.schema("yearToMonth").dataType.isInstanceOf[YearMonthIntervalType])
  df.write
    .mode(SaveMode.Overwrite)
    .format(classOf[DataSource].getName)
    .option("url", SparkConnectorScalaSuiteIT.server.getBoltUrl)
    .option("labels", ":LegacyTypeConversionEnabled")
    .option("node.keys", "id")
    .option(Neo4jOptions.TYPE_CONVERSION, "legacy")
    .save()
  val actual = SparkConnectorScalaSuiteIT.session().run(
    """
      |MATCH (n:LegacyTypeConversionEnabled {id: 'legacy-type-conversion'})
      |RETURN n
      |""".stripMargin
  )
    .single()
    .get("n")
    .asNode()
    .asMap()
    .asScala
    .toMap
  assertEquals(
    Set(
      "id",
      "timestamp",
      "timestampNtz",
      "dayInterval",
      "dayToHour",
      "arithmeticDuration",
      "yearInterval",
      "yearToMonth",
      "binary",
      "byteArray"
    ),
    actual.keySet
  )
  assertEquals("legacy-type-conversion", actual("id"))
  // Legacy conversion stores the timestamp as a local datetime in the JVM
  // default zone.
  val expectedTimestamp = ZonedDateTime.of(2025, 1, 1, 11, 11, 11, 0, ZoneOffset.UTC)
    .withZoneSameInstant(ZoneId.systemDefault())
    .toLocalDateTime
  assertEquals(expectedTimestamp, actual("timestamp"))
  assertEquals(
    DateTimeUtils.localDateTimeToMicros(LocalDateTime.of(2025, 1, 1, 11, 11, 11)),
    actual("timestampNtz").asInstanceOf[java.lang.Number].longValue()
  )
  // Intervals survive as raw microsecond / month counts.
  assertEquals(4L * 24L * 3600L * 1000000L, actual("dayInterval").asInstanceOf[java.lang.Number].longValue())
  assertEquals((10L * 24L + 5L) * 3600L * 1000000L, actual("dayToHour").asInstanceOf[java.lang.Number].longValue())
  assertEquals(
    (367L * 24L * 3600L + 66600L) * 1000000L + 454000L,
    actual("arithmeticDuration").asInstanceOf[java.lang.Number].longValue()
  )
  assertEquals(36L, actual("yearInterval").asInstanceOf[java.lang.Number].longValue())
  assertEquals(14L, actual("yearToMonth").asInstanceOf[java.lang.Number].longValue())
  assertArrayEquals(Array[Byte](101, 114, 105, 107), actual("binary").asInstanceOf[Array[Byte]])
  val byteArray =
actual("byteArray").asInstanceOf[java.util.List[_]].asScala
  .map(_.asInstanceOf[java.lang.Number].longValue())
assertEquals(Seq(1L, 2L, 3L), byteArray)
}
}


================================================
FILE: spark-3/src/test/scala/org/neo4j/spark/DataSourceStreamingReaderTSE.scala
================================================
/*
 * Copyright (c) "Neo4j"
 * Neo4j Sweden AB [https://neo4j.com]
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package org.neo4j.spark

import org.apache.spark.sql.Row
import org.apache.spark.sql.streaming.StreamingQuery
import org.apache.spark.sql.streaming.Trigger
import org.hamcrest.Matchers
import org.junit.After
import org.junit.Assert.assertEquals
import org.junit.Assert.assertTrue
import org.junit.Rule
import org.junit.Test
import org.junit.rules.TemporaryFolder
import org.neo4j.Closeables.use
import org.neo4j.spark.SparkConnectorScalaSuiteIT.session

import java.util.concurrent.Executors
import java.util.concurrent.TimeUnit

import scala.annotation.meta.getter

// Integration tests for the streaming (read) path of the connector,
// covering labels, relationships and custom-query sources, each in
// "from NOW", "from ALL" and checkpoint-resume variants.
class DataSourceStreamingReaderTSE extends SparkConnectorScalaBaseTSE {

  // Temporary folder used as the streaming checkpoint location.
  @(Rule @getter)
  val folder: TemporaryFolder = new TemporaryFolder()

  // Currently running streaming query; stopped after each test.
  private var query: StreamingQuery = _

  @After
  def close(): Unit = {
    if (query != null) {
      query.stop()
    }
  }

  // Streams nodes with a given label created after the stream started
  // ("streaming.from" = NOW) and checks they all eventually arrive.
  @Test
  def testReadStreamWithLabels(): Unit = {
    createMovieNodes(0, 1)
    val stream = ss.readStream.format(classOf[DataSource].getName)
      .option("url", SparkConnectorScalaSuiteIT.server.getBoltUrl)
      .option("labels", "Movie")
      .option("streaming.property.name", "timestamp")
      .option("streaming.from", "NOW")
      .load()
    query = stream.writeStream
      .format("memory")
      .queryName("readStreamWithLabels")
      .start()
    val total = 60
    // NOTE(review): the empty-string key below looks like an angle-bracketed
    // key (e.g. "<labels>") stripped in this copy — verify against VCS.
    val expected: Seq[Map[String, Any]] = (1 to total).map(index =>
      Map(
        "" -> Seq("Movie"),
        "title" -> s"My movie $index"
      )
    )
    // Continue creating nodes in the background
    Executors.newSingleThreadExecutor().submit(new Runnable {
      override def run(): Unit = {
        createMovieNodes(1, total, 1000, 200)
      }
    })
    Assert.assertEventually(
      new Assert.ThrowingSupplier[Seq[Map[String, Any]], Exception] {
        override def get(): Seq[Map[String, Any]] = {
          selectRowsFromTable("select * from readStreamWithLabels order by timestamp", mapMovie)
        }
      },
      Matchers.equalTo(expected),
      30L,
      TimeUnit.SECONDS
    )
  }

  // Same as above but "streaming.from" = ALL: nodes created before the
  // stream started are also emitted.
  @Test
  def testReadStreamWithLabelsGetAll(): Unit = {
    createMovieNodes(0, 1)
    val stream = ss.readStream.format(classOf[DataSource].getName)
      .option("url", SparkConnectorScalaSuiteIT.server.getBoltUrl)
      .option("labels", "Movie")
      .option("streaming.property.name", "timestamp")
      .option("streaming.from", "ALL")
      .load()
    query = stream.writeStream
      .format("memory")
      .queryName("readStreamWithLabelsAll")
      .start()
    val total = 60
    Executors.newSingleThreadExecutor().submit(new Runnable {
      override def run(): Unit = {
        createMovieNodes(1, total, 1000, 200)
      }
    })
    val expected: Seq[Map[String, Any]] = (0 to total).map(index =>
      Map(
        "" -> Seq("Movie"),
        "title" -> s"My movie $index"
      )
    )
    Assert.assertEventually(
      new Assert.ThrowingSupplier[Seq[Map[String, Any]], Exception] {
        override def get(): Seq[Map[String, Any]] = {
          selectRowsFromTable("select * from readStreamWithLabelsAll order by timestamp", mapMovie)
        }
      },
      Matchers.equalTo(expected),
      30L,
      TimeUnit.SECONDS
    )
  }

  // The label stream must resume from its checkpoint across restarts,
  // emitting each node exactly once.
  @Test
  def testReadStreamWithLabelsResumesFromCheckpoint(): Unit = {
    createMovieNodes(0, 1)
    val stream = ss.readStream.format(classOf[DataSource].getName)
      .option("url", SparkConnectorScalaSuiteIT.server.getBoltUrl)
      .option("labels", "Movie")
.option("streaming.property.name", "timestamp")
.option("streaming.from", "NOW")
.load()
val total = 60
val expected: Seq[Map[String, Any]] = (1 to total).map(index =>
  Map(
    "" -> List("Movie"),
    "title" -> s"My movie $index"
  )
)
val checkpoint = folder.newFolder()
// First micro-batch run: drains what is currently available, then stops.
stream.writeStream
  .trigger(Trigger.AvailableNow())
  .option("checkpointLocation", checkpoint.getAbsolutePath)
  .toTable("readStreamWithLabelsCheckpoint")
  .awaitTermination()
val partial: Int = total / 2
// create partial movies starting from 1
createMovieNodes(1, partial, 0, 10)
// fetch whatever is available
stream.writeStream
  .trigger(Trigger.AvailableNow())
  .option("checkpointLocation", checkpoint.getAbsolutePath)
  .toTable("readStreamWithLabelsCheckpoint")
  .awaitTermination()
// create rest of the movies starting from partial+1
createMovieNodes(partial + 1, total - partial, 0, 10)
// fetch rest of the items from where we left off
stream.writeStream
  .trigger(Trigger.AvailableNow())
  .option("checkpointLocation", checkpoint.getAbsolutePath)
  .toTable("readStreamWithLabelsCheckpoint")
  .awaitTermination()
assertEquals(
  expected,
  selectRowsFromTable("select * from readStreamWithLabelsCheckpoint order by timestamp", mapMovie)
)
}

// Streams relationships created after the stream started, flattening
// source/target labels and properties into one row per relationship.
@Test
def testReadStreamWithRelationship(): Unit = {
  createLikesRelationships(0, 1)
  val stream = ss.readStream.format(classOf[DataSource].getName)
    .option("url", SparkConnectorScalaSuiteIT.server.getBoltUrl)
    .option("relationship", "LIKES")
    .option("streaming.property.name", "timestamp")
    .option("streaming.from", "NOW")
    .option("relationship.source.labels", "Person")
    .option("relationship.target.labels", "Post")
    .load()
  query = stream.writeStream
    .format("memory")
    .queryName("readStreamWithRelationship")
    .start()
  val total = 60
  // NOTE(review): the empty-string keys below look like angle-bracketed keys
  // (e.g. "<rel.type>", "<source.labels>", "<target.labels>") stripped in
  // this copy — verify against VCS.
  val expected: Seq[Map[String, Any]] = (1 to total).map(index =>
    Map(
      "" -> "LIKES",
      "" -> Seq("Person"),
      "source.age" -> index,
      "" -> Seq("Post"),
      "target.hash" -> s"hash$index",
      "rel.id" -> index
    )
  )
  Executors.newSingleThreadExecutor().submit(new Runnable {
    override def run(): Unit = {
      createLikesRelationships(1, total, 1000, 200)
    }
  })
  Assert.assertEventually(
    new Assert.ThrowingSupplier[Seq[Map[String, Any]], Exception] {
      override def get(): Seq[Map[String, Any]] = {
        selectRowsFromTable("select * from readStreamWithRelationship order by `rel.timestamp`", mapLikes)
      }
    },
    Matchers.equalTo(expected),
    30L,
    TimeUnit.SECONDS
  )
}

// ALL variant: the relationship created before the stream started is also emitted.
@Test
def testReadStreamWithRelationshipGetAll(): Unit = {
  createLikesRelationships(0, 1)
  val stream = ss.readStream.format(classOf[DataSource].getName)
    .option("url", SparkConnectorScalaSuiteIT.server.getBoltUrl)
    .option("relationship", "LIKES")
    .option("streaming.property.name", "timestamp")
    .option("streaming.from", "ALL")
    .option("relationship.source.labels", "Person")
    .option("relationship.target.labels", "Post")
    .load()
  query = stream.writeStream
    .format("memory")
    .queryName("readStreamWithRelationshipAll")
    .start()
  val total = 60
  val expected: Seq[Map[String, Any]] = (0 to total).map(index =>
    Map(
      "" -> "LIKES",
      "" -> Seq("Person"),
      "source.age" -> index,
      "" -> Seq("Post"),
      "target.hash" -> s"hash$index",
      "rel.id" -> index
    )
  )
  Executors.newSingleThreadExecutor().submit(new Runnable {
    override def run(): Unit = {
      createLikesRelationships(1, total, 1000, 200)
    }
  })
  Assert.assertEventually(
    new Assert.ThrowingSupplier[Seq[Map[String, Any]], Exception] {
      override def get(): Seq[Map[String, Any]] = {
        selectRowsFromTable("select * from readStreamWithRelationshipAll order by `rel.timestamp`", mapLikes)
      }
    },
    Matchers.equalTo(expected),
    30L,
    TimeUnit.SECONDS
  )
}

// The relationship stream must resume from its checkpoint across restarts.
@Test
def testReadStreamWithRelationshipResumesFromCheckpoint(): Unit = {
  createLikesRelationships(0, 1)
  val stream = ss.readStream.format(classOf[DataSource].getName)
    .option("url", SparkConnectorScalaSuiteIT.server.getBoltUrl)
    .option("relationship", "LIKES")
    .option("streaming.property.name", "timestamp")
    .option("streaming.from", "ALL")
    .option("relationship.source.labels", "Person")
    .option("relationship.target.labels", "Post")
    .load()
  val total = 60
  val expected: Seq[Map[String, Any]] = (0 to total).map(index =>
    Map(
      "" -> "LIKES",
      "" -> Seq("Person"),
      "source.age" -> index,
      "" -> Seq("Post"),
      "target.hash" -> s"hash$index",
      "rel.id" -> index
    )
  )
  val checkpoint = folder.newFolder()
  stream.writeStream
    .trigger(Trigger.AvailableNow())
    .option("checkpointLocation", checkpoint.getAbsolutePath)
    .toTable("readStreamWithRelationshipCheckpoint")
    .awaitTermination()
  val partial: Int = total / 2
  // create partial number of likes starting from 1
  createLikesRelationships(1, partial, 0, 10)
  // fetch whatever is available
  stream.writeStream
    .trigger(Trigger.AvailableNow())
    .option("checkpointLocation", checkpoint.getAbsolutePath)
    .toTable("readStreamWithRelationshipCheckpoint")
    .awaitTermination()
  // create rest of the likes starting from partial+1
  createLikesRelationships(partial + 1, total - partial, 0, 10)
  // fetch rest of the items from where we left off
  stream.writeStream
    .trigger(Trigger.AvailableNow())
    .option("checkpointLocation", checkpoint.getAbsolutePath)
    .toTable("readStreamWithRelationshipCheckpoint")
    .awaitTermination()
  assertEquals(
    expected,
    selectRowsFromTable("select * from readStreamWithRelationshipCheckpoint order by `rel.timestamp`", mapLikes)
  )
}

// Streams via a custom Cypher query plus a "streaming.query.offset" query
// that supplies the current offset value for $stream.offset.
@Test
def testReadStreamWithQuery(): Unit = {
  createPersonNodes(0, 1)
  val stream = ss.readStream.format(classOf[DataSource].getName)
    .option("url", SparkConnectorScalaSuiteIT.server.getBoltUrl)
    .option("streaming.from", "NOW")
    .option("streaming.property.name", "timestamp")
    .option(
      "query",
      """
        |MATCH (p:Person)
        |WHERE p.timestamp > $stream.offset
        |RETURN p.age AS age, p.timestamp AS timestamp
        |""".stripMargin
    )
    .option(
      "streaming.query.offset",
      """
        |MATCH (p:Person)
        |RETURN max(p.timestamp)
        |""".stripMargin
    )
    .load()
  query = stream.writeStream
    .format("memory")
    .queryName("readStreamWithQuery")
    .start()
  val total = 60
  val expected: Seq[Map[String, Any]] = (1 to total).map(index =>
    Map(
      "age" -> s"$index"
    )
  )
Executors.newSingleThreadExecutor().submit(new Runnable { override def run(): Unit = { createPersonNodes(1, total, 1000, 200) } }) Assert.assertEventually( new Assert.ThrowingSupplier[Seq[Map[String, Any]], Exception] { override def get(): Seq[Map[String, Any]] = { selectRowsFromTable("select * from readStreamWithQuery order by timestamp", mapPerson) } }, Matchers.equalTo(expected), 30L, TimeUnit.SECONDS ) } @Test def testReadStreamWithQueryGetAll(): Unit = { createPersonNodes(0, 1) val stream = ss.readStream.format(classOf[DataSource].getName) .option("url", SparkConnectorScalaSuiteIT.server.getBoltUrl) .option("streaming.property.name", "timestamp") .option("streaming.from", "ALL") .option( "query", """ |MATCH (p:Person) |WHERE p.timestamp > $stream.offset |RETURN p.age AS age, p.timestamp AS timestamp |""".stripMargin ) .option( "streaming.query.offset", """ |MATCH (p:Person) |RETURN max(p.timestamp) |""".stripMargin ) .load() query = stream.writeStream .format("memory") .queryName("readStreamWithQueryAll") .start() val total = 60 val expected: Seq[Map[String, Any]] = (0 to total).map(index => Map( "age" -> s"$index" ) ).toList Executors.newSingleThreadExecutor().submit(new Runnable { override def run(): Unit = { createPersonNodes(1, total, 1000, 200) } }) Assert.assertEventually( new Assert.ThrowingSupplier[Seq[Map[String, Any]], Exception] { override def get(): Seq[Map[String, Any]] = { selectRowsFromTable("select * from readStreamWithQueryAll order by timestamp", mapPerson) } }, Matchers.equalTo(expected), 30L, TimeUnit.SECONDS ) } @Test def testReadStreamWithQueryResumesFromCheckpoint(): Unit = { createPersonNodes(0, 1) val stream = ss.readStream.format(classOf[DataSource].getName) .option("url", SparkConnectorScalaSuiteIT.server.getBoltUrl) .option("streaming.property.name", "timestamp") .option("streaming.from", "ALL") .option( "query", """ |MATCH (p:Person) |WHERE p.timestamp > $stream.offset |RETURN p.age AS age, p.timestamp AS timestamp |""".stripMargin 
      )
      .option(
        "streaming.query.offset",
        """
          |MATCH (p:Person)
          |RETURN max(p.timestamp)
          |""".stripMargin
      )
      .load()
    val total = 60
    val expected: Seq[Map[String, Any]] = (0 to total).map(index =>
      Map(
        "age" -> s"$index"
      )
    ).toList
    val checkpoint = folder.newFolder()
    // 1st run: drains the single pre-existing node, persisting the offset in `checkpoint`
    stream.writeStream
      .trigger(Trigger.AvailableNow())
      .option("checkpointLocation", checkpoint.getAbsolutePath)
      .toTable("readStreamWithQueryCheckpoint")
      .awaitTermination()
    val partial: Int = total / 2
    // create partial number of persons starting from 1
    createPersonNodes(1, partial, 0, 10)
    // fetch whatever is available
    stream.writeStream
      .trigger(Trigger.AvailableNow())
      .option("checkpointLocation", checkpoint.getAbsolutePath)
      .toTable("readStreamWithQueryCheckpoint")
      .awaitTermination()
    // create rest of the persons starting from partial+1
    createPersonNodes(partial + 1, total - partial, 0, 10)
    // fetch rest of the items from where we left off
    stream.writeStream
      .trigger(Trigger.AvailableNow())
      .option("checkpointLocation", checkpoint.getAbsolutePath)
      .toTable("readStreamWithQueryCheckpoint")
      .awaitTermination()
    // rows must appear exactly once each — no duplicates across the three runs
    assertEquals(
      expected,
      selectRowsFromTable("select * from readStreamWithQueryCheckpoint order by timestamp", mapPerson)
    )
  }

  // Same as the previous test, but the user query consumes the newer $stream.from/$stream.to
  // range parameters instead of the single $stream.offset parameter.
  @Test
  def testReadStreamWithQueryResumesFromCheckpointWithNewParams(): Unit = {
    createPersonNodes(0, 1)
    val stream = ss.readStream.format(classOf[DataSource].getName)
      .option("url", SparkConnectorScalaSuiteIT.server.getBoltUrl)
      .option("streaming.property.name", "timestamp")
      .option("streaming.from", "ALL")
      .option(
        "query",
        """
          |MATCH (p:Person)
          |WHERE p.timestamp > $stream.from AND p.timestamp <= $stream.to
          |RETURN p.age AS age, p.timestamp AS timestamp
          |""".stripMargin
      )
      .option(
        "streaming.query.offset",
        """
          |MATCH (p:Person)
          |RETURN max(p.timestamp)
          |""".stripMargin
      )
      .load()
    val total = 60
    val expected: Seq[Map[String, Any]] = (0 to total).map(index =>
      Map(
        "age" -> s"$index"
      )
    ).toList
    val checkpoint = folder.newFolder()
    stream.writeStream
      .trigger(Trigger.AvailableNow())
      .option("checkpointLocation", checkpoint.getAbsolutePath)
      .toTable("readStreamWithQueryCheckpointNewParams")
      .awaitTermination()
    val partial: Int = total / 2
    // create partial number of persons starting from 1
    createPersonNodes(1, partial, 0, 10)
    // fetch whatever is available
    stream.writeStream
      .trigger(Trigger.AvailableNow())
      .option("checkpointLocation", checkpoint.getAbsolutePath)
      .toTable("readStreamWithQueryCheckpointNewParams")
      .awaitTermination()
    // create rest of the persons starting from partial+1
    createPersonNodes(partial + 1, total - partial, 0, 10)
    // fetch rest of the items from where we left off
    stream.writeStream
      .trigger(Trigger.AvailableNow())
      .option("checkpointLocation", checkpoint.getAbsolutePath)
      .toTable("readStreamWithQueryCheckpointNewParams")
      .awaitTermination()
    assertEquals(
      expected,
      selectRowsFromTable("select * from readStreamWithQueryCheckpointNewParams order by timestamp", mapPerson)
    )
  }

  // When the streaming.query.offset query returns no value (no :Human nodes remain), the
  // connector must keep the previous offset and read zero rows rather than resetting.
  @Test
  def testStreamDoesNotReadAnyDataIfStreamingQueryReturnsNothing(): Unit = {
    createPersonNodes(0, 50)
    // tag every Person as :Human so the offset query initially finds them
    use(session()) { session =>
      session.run("MATCH (p:Person) SET p:Human").consume()
    }
    val stream = ss.readStream.format(classOf[DataSource].getName)
      .option("url", SparkConnectorScalaSuiteIT.server.getBoltUrl)
      .option("streaming.property.name", "timestamp")
      .option("streaming.from", "ALL")
      .option(
        "query",
        """
          |MATCH (p:Person)
          |WHERE p.timestamp > $stream.from AND p.timestamp <= $stream.to
          |RETURN p.age AS age, p.timestamp AS timestamp
          |""".stripMargin
      )
      .option("streaming.query.offset", "MATCH (p:Human) RETURN max(p.timestamp)")
      .load()
    val checkpoint = folder.newFolder()
    // 1st trigger: every node was processed
    var streamTable = stream.writeStream
      .trigger(Trigger.AvailableNow())
      .option("checkpointLocation", checkpoint.getAbsolutePath)
      .toTable("testStreamDoesNotReadAnyDataIfStreamingQueryReturnsNothing")
    streamTable.awaitTermination()
    val firstSource = streamTable.lastProgress.sources.head
    assertEquals(50, firstSource.numInputRows)
    val
    endOffset1 = firstSource.endOffset
    // drop the :Human label so the offset query now returns null
    use(session()) { session =>
      session.run("MATCH (p:Human) REMOVE p:Human").consume()
    }
    // 2nd trigger: previous offset is kept intact
    streamTable = stream.writeStream
      .trigger(Trigger.AvailableNow())
      .option("checkpointLocation", checkpoint.getAbsolutePath)
      .toTable("testStreamDoesNotReadAnyDataIfStreamingQueryReturnsNothing")
    streamTable.awaitTermination()
    val source = streamTable.lastProgress.sources.head
    assertEquals(endOffset1, source.startOffset)
    assertEquals(endOffset1, source.endOffset)
    assertEquals(0L, source.numInputRows)
  }

  /**
   * Creates `maxAge` Person nodes with string ages `minAge`, `minAge+1`, ...
   * NOTE(review): despite its name, `maxAge` is used as a COUNT, not an upper bound
   * (`minAge until minAge + maxAge`).
   *
   * @param delayMs    sleep before the first write (lets a stream start first)
   * @param intervalMs sleep between consecutive writes (spreads out timestamps)
   */
  private def createPersonNodes(minAge: Int, maxAge: Int, delayMs: Int = 0, intervalMs: Int = 0): Unit = {
    use(session()) { session =>
      Thread.sleep(delayMs)
      (minAge until minAge + maxAge).foreach(age => {
        Thread.sleep(intervalMs)
        session.run(
          s"CREATE (p:Person {age: '$age', timestamp: timestamp()})"
        ).consume()
      })
    }
  }

  // Creates `count` Movie nodes titled 'My movie <index>' starting at `from`; same
  // delay/interval semantics as createPersonNodes.
  private def createMovieNodes(from: Int, count: Int, delayMs: Int = 0, intervalMs: Int = 0): Unit = {
    use(session()) { session =>
      Thread.sleep(delayMs)
      (from until from + count).foreach(index => {
        Thread.sleep(intervalMs)
        session.run(
          s"CREATE (n:Movie {title: 'My movie $index', timestamp: timestamp()})"
        ).consume()
      })
    }
  }

  // Creates `count` (Person)-[:LIKES]->(Post) triples, one fresh node pair per relationship.
  private def createLikesRelationships(from: Int, count: Int, delayMs: Int = 0, intervalMs: Int = 0): Unit = {
    use(session()) { session =>
      Thread.sleep(delayMs)
      (from until from + count).foreach(index => {
        Thread.sleep(intervalMs)
        session.run(
          s"""
             |CREATE (person:Person {age: $index})
             |CREATE (post:Post {hash: "hash$index"})
             |CREATE (person)-[:LIKES{id: $index, timestamp: timestamp()}]->(post)
             |""".stripMargin
        ).consume()
      })
    }
  }

  // Runs a Spark SQL query against an in-memory sink table and projects each Row
  // through `mapper` for easy equality checks.
  private def selectRowsFromTable(
      query: String,
      mapper: (Row) => Map[String, Any]
  ): Seq[Map[String, Any]] = {
    ss.sql(query)
      .collect()
      .map(row => mapper(row))
      .toList
  }

  // Projects only the `age` column of a Person row.
  private def mapPerson(row: Row): Map[String, Any] = {
    Map(
      "age" -> row.getAs[String]("age")
    )
  }

  private def mapMovie(row: Row): Map[String, Any] = {
    Map(
      "" ->
      // NOTE(review): these "" column names look like extraction-mangled metadata columns
      // (likely "<labels>" / "<rel.type>" / "<source.labels>" / "<target.labels>") —
      // verify against the original file. As written, the duplicate "" keys in the Map
      // literal below collapse to a single entry (last one wins).
      row.getAs[java.util.List[String]](""),
      "title" -> row.getAs[String]("title")
    )
  }

  private def mapLikes(row: Row): Map[String, Any] = {
    Map(
      "" -> row.getAs[String](""),
      "" -> row.getAs[java.util.List[String]](""),
      "source.age" -> row.getAs[Long]("source.age"),
      "" -> row.getAs[java.util.List[String]](""),
      "target.hash" -> row.getAs[String]("target.hash"),
      "rel.id" -> row.getAs[Long]("rel.id")
    )
  }
}


================================================
FILE: spark-3/src/test/scala/org/neo4j/spark/DataSourceStreamingWriterTSE.scala
================================================
/*
 * Copyright (c) "Neo4j"
 * Neo4j Sweden AB [https://neo4j.com]
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package org.neo4j.spark

import org.apache.spark.sql.execution.streaming.MemoryStream
import org.apache.spark.sql.streaming.StreamingQuery
import org.hamcrest.Matchers
import org.junit.After
import org.junit.Test
import org.neo4j.driver.Transaction
import org.neo4j.driver.TransactionWork
import org.neo4j.spark.Assert.ThrowingSupplier

import java.util.UUID
import java.util.concurrent.TimeUnit

// Streaming WRITE tests: pipe a MemoryStream[Int] into Neo4j via the connector sink,
// then read back through the connector to verify what landed in the database.
class DataSourceStreamingWriterTSE extends SparkConnectorScalaBaseTSE {

  // The currently running streaming query, if any; stopped after each test.
  private var query: StreamingQuery = null

  @After
  def close(): Unit = {
    if (query != null) {
      query.stop()
    }
  }

  // Append mode with node labels: all 10_000 distinct values must end up as :Timestamp nodes.
  @Test
  def testSinkStreamWithLabelsWithAppend(): Unit = {
    implicit val ctx = ss.sqlContext
    import ss.implicits._
    val memStream = MemoryStream[Int]
    val recordSize = 2000
    val partition = 5
    val checkpointLocation = "/tmp/checkpoint/" + UUID.randomUUID().toString
    query = memStream.toDF().writeStream
      .format(classOf[DataSource].getName)
      .option("url", SparkConnectorScalaSuiteIT.server.getBoltUrl)
      .option("save.mode", "Append")
      .option("labels", "Timestamp")
      .option("checkpointLocation", checkpointLocation)
      .option("node.keys", "value")
      .start()
    (1 to partition).foreach(index => {
      // we send the total of records in 5 times
      val start = ((index - 1) * recordSize) + 1
      val end = index * recordSize
      memStream.addData((start to end).toArray)
    })
    Assert.assertEventually(
      new ThrowingSupplier[Boolean, Exception] {
        override def get(): Boolean = {
          val dataFrame = ss.read.format(classOf[DataSource].getName)
            .option("url", SparkConnectorScalaSuiteIT.server.getBoltUrl)
            .option("labels", "Timestamp")
            .load()
          val collect = dataFrame.collect()
          // the `value` column only exists once at least one node has been written
          val data = if (dataFrame.columns.contains("value")) {
            collect
              .map(row => row.getAs[Long]("value").toInt)
              .sorted
          } else {
            Array.empty[Int]
          }
          data.toList == (1 to (recordSize * partition)).toList
        }
      },
      Matchers.equalTo(true),
      30L,
      TimeUnit.SECONDS
    )
  }

  // Append mode with the relationship writer: each value becomes (:From)-[:PAIRS]->(:To).
  @Test
  def testSinkStreamWithRelationshipWithAppend(): Unit = {
    implicit val ctx = ss.sqlContext
    import ss.implicits._
    val memStream =
    MemoryStream[Int]
    val recordSize = 2000
    val partition = 5
    val checkpointLocation = "/tmp/checkpoint/" + UUID.randomUUID().toString
    query = memStream.toDF().writeStream
      .format(classOf[DataSource].getName)
      .option("url", SparkConnectorScalaSuiteIT.server.getBoltUrl)
      .option("save.mode", "Append")
      .option("relationship", "PAIRS")
      .option("relationship.save.strategy", "keys")
      .option("relationship.source.labels", ":From")
      .option("relationship.source.node.keys", "value")
      .option("relationship.source.save.mode", "Append")
      .option("relationship.target.labels", ":To")
      .option("relationship.target.node.keys", "value")
      .option("relationship.target.save.mode", "Append")
      .option("checkpointLocation", checkpointLocation)
      .start()
    (1 to partition).foreach(index => {
      // we send the total of records in 5 times
      val start = ((index - 1) * recordSize) + 1
      val end = index * recordSize
      memStream.addData((start to end).toArray)
    })
    Assert.assertEventually(
      new ThrowingSupplier[Boolean, Exception] {
        // the read may fail transiently while the sink is still writing; treat any
        // throwable as "not yet" and let assertEventually retry
        override def get(): Boolean = try {
          val dataFrame = ss.read.format(classOf[DataSource].getName)
            .option("url", SparkConnectorScalaSuiteIT.server.getBoltUrl)
            .option("relationship", "PAIRS")
            .option("relationship.source.labels", ":From")
            .option("relationship.target.labels", ":To")
            .load()
          val collect = dataFrame.collect()
          val data = if (dataFrame.columns.contains("source.value") && dataFrame.columns.contains("target.value")) {
            collect
              .map(row => (row.getAs[Long]("source.value").toInt, row.getAs[Long]("target.value").toInt))
              .sorted
          } else {
            Array.empty[(Int, Int)]
          }
          // each streamed value v must yield exactly one (v, v) source/target pair
          data.toList == (1 to (recordSize * partition)).map(v => (v, v)).toList
        } catch {
          case _: Throwable => false
        }
      },
      Matchers.equalTo(true),
      30L,
      TimeUnit.SECONDS
    )
  }

  // Custom-query sink: each incoming event is merged via user-provided Cypher.
  @Test
  def testSinkStreamWithQuery(): Unit = {
    implicit val ctx = ss.sqlContext
    import ss.implicits._
    val memStream = MemoryStream[Int]
    val recordSize = 2000
    val partition = 5
    val checkpointLocation = "/tmp/checkpoint/" + UUID.randomUUID().toString
    query =
    memStream.toDF().writeStream
      .format(classOf[DataSource].getName)
      .option("url", SparkConnectorScalaSuiteIT.server.getBoltUrl)
      // `event` is the per-row binding exposed to the user query by the connector sink
      .option("query", "MERGE (m:MyNewNode {the_value: event.value})")
      .option("checkpointLocation", checkpointLocation)
      .start()
    (1 to partition).foreach(index => {
      // we send the total of records in 5 times
      val start = ((index - 1) * recordSize) + 1
      val end = index * recordSize
      memStream.addData((start to end).toArray)
    })
    Assert.assertEventually(
      new ThrowingSupplier[Boolean, Exception] {
        override def get(): Boolean = try {
          val dataFrame = ss.read.format(classOf[DataSource].getName)
            .option("url", SparkConnectorScalaSuiteIT.server.getBoltUrl)
            .option("labels", "MyNewNode")
            .load()
          val collect = dataFrame.collect()
          val data = if (dataFrame.columns.contains("the_value")) {
            collect
              .map(row => row.getAs[Long]("the_value").toInt)
              .sorted
          } else {
            Array.empty[Int]
          }
          val l1 = data.toList
          val l2 = (1 to (recordSize * partition)).map(v => v).toList
          l1 == l2
        } catch {
          case _: Throwable => false
        }
      },
      Matchers.equalTo(true),
      30L,
      TimeUnit.SECONDS
    )
  }

  // Overwrite mode: the same 1..500 batch is streamed 5 times; the unique constraint
  // plus MERGE semantics must leave exactly 500 nodes.
  @Test
  def testSinkStreamWithLabelsWithOverwrite(): Unit = {
    implicit val ctx = ss.sqlContext
    import ss.implicits._
    val memStream = MemoryStream[Int]
    val partition = 5
    val checkpointLocation = "/tmp/checkpoint/" + UUID.randomUUID().toString
    SparkConnectorScalaSuiteIT.session().run(
      "CREATE CONSTRAINT timestamp_value FOR (t:Timestamp) REQUIRE (t.value) IS UNIQUE"
    )
    query = memStream.toDF().writeStream
      .format(classOf[DataSource].getName)
      .option("url", SparkConnectorScalaSuiteIT.server.getBoltUrl)
      .option("save.mode", "Overwrite")
      .option("labels", "Timestamp")
      .option("checkpointLocation", checkpointLocation)
      .option("node.keys", "value")
      .start()
    (1 to partition).foreach(index => {
      memStream.addData((1 to 500).toArray)
    })
    Assert.assertEventually(
      new ThrowingSupplier[Boolean, Exception] {
        override def get(): Boolean = {
          val dataFrame = ss.read.format(classOf[DataSource].getName)
            .option("url",
              SparkConnectorScalaSuiteIT.server.getBoltUrl)
            .option("labels", "Timestamp")
            .load()
          val collect = dataFrame.collect()
          val data = if (dataFrame.columns.contains("value")) {
            collect
              .map(row => row.getAs[Long]("value").toInt)
              .sorted
          } else {
            Array.empty[Int]
          }
          // duplicates across the 5 identical batches must have been merged away
          data.toList == (1 to 500).toList
        }
      },
      Matchers.equalTo(true),
      30L,
      TimeUnit.SECONDS
    )
    // clean up the constraint so later tests start from a neutral schema
    SparkConnectorScalaSuiteIT.session().run("DROP CONSTRAINT timestamp_value")
  }

  // Relationship sink with Append on the relationship but Overwrite (MERGE) on both
  // end nodes: nodes are deduplicated, relationships are not.
  @Test
  def testSinkStreamWithRelationshipWithAppendAndOverwrite(): Unit = {
    implicit val ctx = ss.sqlContext
    import ss.implicits._
    val memStream = MemoryStream[Int]
    val partition = 5
    val checkpointLocation = "/tmp/checkpoint/" + UUID.randomUUID().toString
    SparkConnectorScalaSuiteIT.driver.session()
      .writeTransaction(
        new TransactionWork[Unit] {
          override def execute(tx: Transaction): Unit = {
            tx.run("CREATE CONSTRAINT From_value FOR (p:From) REQUIRE p.value IS UNIQUE")
            tx.run("CREATE CONSTRAINT To_value FOR (p:To) REQUIRE p.value IS UNIQUE")
          }
        }
      )
    query = memStream.toDF().writeStream
      .format(classOf[DataSource].getName)
      .option("url", SparkConnectorScalaSuiteIT.server.getBoltUrl)
      .option("save.mode", "Append")
      .option("relationship", "PAIRS")
      .option("relationship.save.strategy", "keys")
      .option("relationship.source.labels", ":From")
      .option("relationship.source.node.keys", "value")
      .option("relationship.source.save.mode", "Overwrite")
      .option("relationship.target.labels", ":To")
      .option("relationship.target.node.keys", "value")
      .option("relationship.target.save.mode", "Overwrite")
      .option("checkpointLocation", checkpointLocation)
      .start()
    (1 to partition).foreach(index => {
      memStream.addData((1 to 500).toArray)
    })
    Assert.assertEventually(
      new ThrowingSupplier[Boolean, Exception] {
        override def get(): Boolean = try {
          val dataFrame = ss.read.format(classOf[DataSource].getName)
            .option("url", SparkConnectorScalaSuiteIT.server.getBoltUrl)
            .option("relationship", "PAIRS")
            .option("relationship.source.labels", ":From")
            .option("relationship.target.labels", ":To")
            .load()
          val collect = dataFrame.collect()
          val data = if (dataFrame.columns.contains("source.value") && dataFrame.columns.contains("target.value")) {
            collect
              .map(row => (row.getAs[Long]("source.value").toInt, row.getAs[Long]("target.value").toInt))
              .sorted
          } else {
            Array.empty[(Int, Int)]
          }
          // 5 batches x Append relationships => five (v, v) pairs per value
          data.toList == (1 to 500).flatMap(v => (1 to 5).map(_ => (v, v)))
        } catch {
          case _: Throwable => false
        }
      },
      Matchers.equalTo(true),
      30L,
      TimeUnit.SECONDS
    )
    // drop the constraints created for this test
    SparkConnectorScalaSuiteIT.driver.session()
      .writeTransaction(
        new TransactionWork[Unit] {
          override def execute(tx: Transaction): Unit = {
            tx.run("DROP CONSTRAINT From_value")
            tx.run("DROP CONSTRAINT To_value")
          }
        }
      )
  }
}


================================================
FILE: spark-3/src/test/scala/org/neo4j/spark/DataSourceWriterNeo4jSkipNullKeysTSE.scala
================================================
/*
 * Copyright (c) "Neo4j"
 * Neo4j Sweden AB [https://neo4j.com]
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package org.neo4j.spark

import org.apache.spark.SparkException
import org.apache.spark.sql.SaveMode
import org.junit.Assume
import org.junit.Test
import org.neo4j.Closeables.use
import org.neo4j.caniuse.CanIUse
import org.neo4j.caniuse.Schema

// Batch-write tests for the `*.keys.skip.nulls` options: without them a null key
// property makes the write fail; with them the offending rows are silently skipped.
class DataSourceWriterNeo4jSkipNullKeysTSE extends SparkConnectorScalaBaseTSE {

  import ss.implicits._

  // Without skip.nulls, a null node key (the "Moon" row) must fail the whole write.
  @Test
  def `fails to write nodes when key properties contain null values`(): Unit = {
    val cities = Seq(
      (Some(1), "Cherbourg en Cotentin"),
      (Some(2), "London"),
      (Some(3), "Malmö"),
      (None, "Moon")
    ).toDF("id", "city")
    val caught = intercept[SparkException] {
      cities.write
        .format(classOf[DataSource].getName)
        .mode(SaveMode.Overwrite)
        .option("url", SparkConnectorScalaSuiteIT.server.getBoltUrl)
        .option("labels", ":City")
        .option("node.keys", "id")
        .option("schema.optimization.node.keys", "KEY")
        .save()
    }
    assert(caught.getMessage contains "Cannot merge the following node because of null property value")
  }

  // A null SOURCE node key must likewise fail the relationship write.
  @Test
  def `fails to write relationships when source node key properties contain null values`(): Unit = {
    val caught = intercept[SparkException] {
      val cities = Seq(
        (Some(1), Some(2), "British Airways"),
        (Some(2), Some(3), "Turkish Airlines"),
        (None, Some(5), "Another Airline")
      ).toDF("from", "to", "airline")
      cities.write
        .format(classOf[DataSource].getName)
        .mode(SaveMode.Overwrite)
        .option("url", SparkConnectorScalaSuiteIT.server.getBoltUrl)
        .option("relationship", "FLIES_TO")
        .option("relationship.save.strategy", "keys")
        .option("relationship.source.save.mode", "Overwrite")
        .option("relationship.source.labels", ":City")
        .option("relationship.source.node.keys", "from:id")
        .option("relationship.target.save.mode", "Overwrite")
        .option("relationship.target.labels", ":City")
        .option("relationship.target.node.keys", "to:id")
        .option("relationship.properties", "airline")
        .option("schema.optimization.node.keys", "KEY")
        .save()
    }
    assert(caught.getMessage contains "Cannot merge the following node because of null property value")
  }
  // A null TARGET node key must fail the relationship write.
  @Test
  def `fails to write relationships when target node key properties contain null values`(): Unit = {
    val caught = intercept[SparkException] {
      val cities = Seq(
        (Some(1), Some(2), "British Airways"),
        (Some(2), Some(3), "Turkish Airlines"),
        (Some(3), None, "Another Airline")
      ).toDF("from", "to", "airline")
      cities.write
        .format(classOf[DataSource].getName)
        .mode(SaveMode.Overwrite)
        .option("url", SparkConnectorScalaSuiteIT.server.getBoltUrl)
        .option("relationship", "FLIES_TO")
        .option("relationship.save.strategy", "keys")
        .option("relationship.source.save.mode", "Overwrite")
        .option("relationship.source.labels", ":City")
        .option("relationship.source.node.keys", "from:id")
        .option("relationship.target.save.mode", "Overwrite")
        .option("relationship.target.labels", ":City")
        .option("relationship.target.node.keys", "to:id")
        .option("relationship.properties", "airline")
        .option("schema.optimization.node.keys", "KEY")
        .save()
    }
    assert(caught.getMessage contains "Cannot merge the following node because of null property value")
  }

  // A null RELATIONSHIP key must fail; only runs on servers supporting relationship
  // key constraints (guarded by the CanIUse assumption).
  @Test
  def `fails to write relationships when relationship key properties contain null values`(): Unit = {
    Assume.assumeTrue(
      CanIUse.INSTANCE.canIUse(Schema.INSTANCE.relationshipKeyConstraints()).withNeo4j(SparkConnectorScalaSuiteIT.neo4j)
    )
    val caught = intercept[SparkException] {
      val cities = Seq(
        (Some(1), Some(2), Some("BA721"), "British Airways"),
        (Some(2), Some(3), Some("TK211"), "Turkish Airlines"),
        (Some(3), Some(4), None, "Another Airline")
      ).toDF("from", "to", "flight", "airline")
      cities.write
        .format(classOf[DataSource].getName)
        .mode(SaveMode.Overwrite)
        .option("url", SparkConnectorScalaSuiteIT.server.getBoltUrl)
        .option("relationship", "FLIES_TO")
        .option("relationship.save.strategy", "keys")
        .option("relationship.source.save.mode", "Overwrite")
        .option("relationship.source.labels", ":City")
        .option("relationship.source.node.keys", "from:id")
        .option("relationship.target.save.mode", "Overwrite")
        .option("relationship.target.labels",
          ":City")
        .option("relationship.target.node.keys", "to:id")
        .option("relationship.keys", "flight")
        .option("relationship.properties", "airline")
        .option("schema.optimization.node.keys", "KEY")
        .option("schema.optimization.relationship.keys", "KEY")
        .save()
    }
    assert(caught.getMessage contains "Cannot merge the following relationship because of null property value")
  }

  // With node.keys.skip.nulls, the null-key "Moon" row is dropped and the other 3 written.
  @Test
  def `skips nodes when key properties contain null values with APPEND mode`(): Unit = {
    val cities = Seq(
      (Some(1), "Cherbourg en Cotentin"),
      (Some(2), "London"),
      (Some(3), "Malmö"),
      (None, "Moon")
    ).toDF("id", "city")
    cities.write
      .format(classOf[DataSource].getName)
      .mode(SaveMode.Append)
      .option("url", SparkConnectorScalaSuiteIT.server.getBoltUrl)
      .option("labels", ":City")
      .option("node.keys", "id")
      .option("node.keys.skip.nulls", "true")
      .save()
    use(SparkConnectorScalaSuiteIT.driver.session()) { session =>
      val result = session.run("MATCH (n:City) RETURN count(n) as count")
        .single()
        .get("count")
        .asLong()
      assert(result == 3)
    }
  }

  // Same skip behavior under Overwrite (MERGE) mode with the node-key optimization on.
  @Test
  def `skips nodes when key properties contain null values with OVERWRITE mode`(): Unit = {
    val cities = Seq(
      (Some(1), "Cherbourg en Cotentin"),
      (Some(2), "London"),
      (Some(3), "Malmö"),
      (None, "Moon")
    ).toDF("id", "city")
    cities.write
      .format(classOf[DataSource].getName)
      .mode(SaveMode.Overwrite)
      .option("url", SparkConnectorScalaSuiteIT.server.getBoltUrl)
      .option("labels", ":City")
      .option("node.keys", "id")
      .option("node.keys.skip.nulls", "true")
      .option("schema.optimization.node.keys", "KEY")
      .save()
    use(SparkConnectorScalaSuiteIT.driver.session()) { session =>
      val result = session.run("MATCH (n:City) RETURN count(n) as count")
        .single()
        .get("count")
        .asLong()
      assert(result == 3)
    }
  }

  // Rows with a null source OR target key are skipped entirely: neither the relationship
  // nor the non-null end node of the skipped row may be created.
  @Test
  def `skips relationships when source or target node key properties contain null values`(): Unit = {
    val cities = Seq(
      (Some(1), Some(2), "British Airways"),
      (Some(2), Some(3), "Turkish Airlines"),
      (None, Some(5), "Another Airline"),
      (Some(5), None, "Another Airline")
    ).toDF("from", "to", "airline")
    cities.write
      .format(classOf[DataSource].getName)
      .mode(SaveMode.Overwrite)
      .option("url", SparkConnectorScalaSuiteIT.server.getBoltUrl)
      .option("relationship", "FLIES_TO")
      .option("relationship.save.strategy", "keys")
      .option("relationship.source.save.mode", "Overwrite")
      .option("relationship.source.labels", ":City")
      .option("relationship.source.node.keys", "from:id")
      .option("relationship.target.save.mode", "Overwrite")
      .option("relationship.target.labels", ":City")
      .option("relationship.target.node.keys", "to:id")
      .option("relationship.properties", "airline")
      .option("schema.optimization.node.keys", "KEY")
      .option("relationship.source.node.keys.skip.nulls", "true")
      .option("relationship.target.node.keys.skip.nulls", "true")
      .save()
    use(SparkConnectorScalaSuiteIT.driver.session()) { session =>
      // only ids 1, 2, 3 — the two rows mentioning id 5 both had a null on the other end
      val cities = session.run("MATCH (n:City) RETURN count(n) as count")
        .single()
        .get("count")
        .asLong()
      assert(cities == 3)
      val citiesWithId5 = session.run("MATCH (n:City {id: 5}) RETURN count(n) as count")
        .single()
        .get("count")
        .asLong()
      assert(citiesWithId5 == 0)
      val flies = session.run("MATCH ()-[r:FLIES_TO]->() RETURN count(r) as count")
        .single()
        .get("count")
        .asLong()
      assert(flies == 2)
    }
  }

  // With save.mode Match, pre-existing nodes are untouched (still 4) and only the two
  // fully-keyed rows produce relationships.
  @Test
  def `skips relationships when source or target node key properties contain null values when nodes are matched`()
    : Unit = {
    use(SparkConnectorScalaSuiteIT.driver.session()) { session =>
      session.run("UNWIND [1,2,3,5] AS id CREATE (:City {id: id})").consume()
    }
    val cities = Seq(
      (Some(1), Some(2), "British Airways"),
      (Some(2), Some(3), "Turkish Airlines"),
      (None, Some(5), "Another Airline"),
      (Some(5), None, "Another Airline")
    ).toDF("from", "to", "airline")
    cities.write
      .format(classOf[DataSource].getName)
      .mode(SaveMode.Overwrite)
      .option("url", SparkConnectorScalaSuiteIT.server.getBoltUrl)
      .option("relationship", "FLIES_TO")
      .option("relationship.save.strategy", "keys")
      .option("relationship.source.save.mode", "Match")
      .option("relationship.source.labels", ":City")
      .option("relationship.source.node.keys", "from:id")
      .option("relationship.target.save.mode", "Match")
      .option("relationship.target.labels", ":City")
      .option("relationship.target.node.keys", "to:id")
      .option("relationship.properties", "airline")
      .option("relationship.source.node.keys.skip.nulls", "true")
      .option("relationship.target.node.keys.skip.nulls", "true")
      .save()
    use(SparkConnectorScalaSuiteIT.driver.session()) { session =>
      // the 4 pre-created cities remain; skipped rows added nothing
      val cities = session.run("MATCH (n:City) RETURN count(n) as count")
        .single()
        .get("count")
        .asLong()
      assert(cities == 4)
      val flies = session.run("MATCH ()-[r:FLIES_TO]->() RETURN count(r) as count")
        .single()
        .get("count")
        .asLong()
      assert(flies == 2)
    }
  }

  // With save.mode Append on the end nodes, each non-skipped row creates fresh nodes;
  // rows containing a null key create nothing at all (no stray id-5 node).
  @Test
  def `skips relationships when source or target node key properties contain null values when nodes are appended`()
    : Unit = {
    val cities = Seq(
      (Some(1), Some(2), "British Airways"),
      (Some(3), Some(4), "Turkish Airlines"),
      (None, Some(5), "Another Airline"),
      (Some(5), None, "Another Airline")
    ).toDF("from", "to", "airline")
    cities.write
      .format(classOf[DataSource].getName)
      .mode(SaveMode.Overwrite)
      .option("url", SparkConnectorScalaSuiteIT.server.getBoltUrl)
      .option("relationship", "FLIES_TO")
      .option("relationship.save.strategy", "keys")
      .option("relationship.source.save.mode", "Append")
      .option("relationship.source.labels", ":City")
      .option("relationship.source.node.keys", "from:id")
      .option("relationship.target.save.mode", "Append")
      .option("relationship.target.labels", ":City")
      .option("relationship.target.node.keys", "to:id")
      .option("relationship.properties", "airline")
      .option("relationship.source.node.keys.skip.nulls", "true")
      .option("relationship.target.node.keys.skip.nulls", "true")
      .save()
    use(SparkConnectorScalaSuiteIT.driver.session()) { session =>
      val cities = session.run("MATCH (n:City) RETURN count(n) as count")
        .single()
        .get("count")
        .asLong()
      assert(cities == 4)
      val citiesWithId5 = session.run("MATCH (n:City {id:
        5}) RETURN count(n) as count")
        .single()
        .get("count")
        .asLong()
      assert(citiesWithId5 == 0)
      val flies = session.run("MATCH ()-[r:FLIES_TO]->() RETURN count(r) as count")
        .single()
        .get("count")
        .asLong()
      assert(flies == 2)
    }
  }

  // Same skip semantics when the DataFrame-level mode is Append instead of Overwrite.
  @Test
  def `skips relationships when source or target node key properties contain null values with append mode`(): Unit = {
    val cities = Seq(
      (Some(1), Some(2), "British Airways"),
      (Some(3), Some(4), "Turkish Airlines"),
      (None, Some(5), "Another Airline"),
      (Some(5), None, "Another Airline")
    ).toDF("from", "to", "airline")
    cities.write
      .format(classOf[DataSource].getName)
      .mode(SaveMode.Append)
      .option("url", SparkConnectorScalaSuiteIT.server.getBoltUrl)
      .option("relationship", "FLIES_TO")
      .option("relationship.save.strategy", "keys")
      .option("relationship.source.save.mode", "Overwrite")
      .option("relationship.source.labels", ":City")
      .option("relationship.source.node.keys", "from:id")
      .option("relationship.target.save.mode", "Overwrite")
      .option("relationship.target.labels", ":City")
      .option("relationship.target.node.keys", "to:id")
      .option("relationship.properties", "airline")
      .option("schema.optimization.node.keys", "KEY")
      .option("relationship.source.node.keys.skip.nulls", "true")
      .option("relationship.target.node.keys.skip.nulls", "true")
      .save()
    use(SparkConnectorScalaSuiteIT.driver.session()) { session =>
      val cities = session.run("MATCH (n:City) RETURN count(n) as count")
        .single()
        .get("count")
        .asLong()
      assert(cities == 4)
      val citiesWithId5 = session.run("MATCH (n:City {id: 5}) RETURN count(n) as count")
        .single()
        .get("count")
        .asLong()
      assert(citiesWithId5 == 0)
      val flies = session.run("MATCH ()-[r:FLIES_TO]->() RETURN count(r) as count")
        .single()
        .get("count")
        .asLong()
      assert(flies == 2)
    }
  }

  // Append mode + Match end nodes: pre-existing nodes matched, null-key rows skipped.
  @Test
  def `skips relationships when source or target node key properties contain null values when nodes are matched with append mode`()
    : Unit = {
    use(SparkConnectorScalaSuiteIT.driver.session()) { session =>
      session.run("UNWIND [1,2,3,5] AS id
        CREATE (:City {id: id})").consume()
    }
    val cities = Seq(
      (Some(1), Some(2), "British Airways"),
      (Some(2), Some(3), "Turkish Airlines"),
      (None, Some(5), "Another Airline"),
      (Some(5), None, "Another Airline")
    ).toDF("from", "to", "airline")
    cities.write
      .format(classOf[DataSource].getName)
      .mode(SaveMode.Append)
      .option("url", SparkConnectorScalaSuiteIT.server.getBoltUrl)
      .option("relationship", "FLIES_TO")
      .option("relationship.save.strategy", "keys")
      .option("relationship.source.save.mode", "Match")
      .option("relationship.source.labels", ":City")
      .option("relationship.source.node.keys", "from:id")
      .option("relationship.target.save.mode", "Match")
      .option("relationship.target.labels", ":City")
      .option("relationship.target.node.keys", "to:id")
      .option("relationship.properties", "airline")
      .option("relationship.source.node.keys.skip.nulls", "true")
      .option("relationship.target.node.keys.skip.nulls", "true")
      .save()
    use(SparkConnectorScalaSuiteIT.driver.session()) { session =>
      val cities = session.run("MATCH (n:City) RETURN count(n) as count")
        .single()
        .get("count")
        .asLong()
      assert(cities == 4)
      val flies = session.run("MATCH ()-[r:FLIES_TO]->() RETURN count(r) as count")
        .single()
        .get("count")
        .asLong()
      assert(flies == 2)
    }
  }

  // Append mode + Append end nodes: fully-keyed rows create fresh nodes/relationships,
  // null-key rows create nothing.
  @Test
  def `skips relationships when source or target node key properties contain null values when nodes are appended with append mode`()
    : Unit = {
    val cities = Seq(
      (Some(1), Some(2), "British Airways"),
      (Some(3), Some(4), "Turkish Airlines"),
      (None, Some(5), "Another Airline"),
      (Some(5), None, "Another Airline")
    ).toDF("from", "to", "airline")
    cities.write
      .format(classOf[DataSource].getName)
      .mode(SaveMode.Append)
      .option("url", SparkConnectorScalaSuiteIT.server.getBoltUrl)
      .option("relationship", "FLIES_TO")
      .option("relationship.save.strategy", "keys")
      .option("relationship.source.save.mode", "Append")
      .option("relationship.source.labels", ":City")
      .option("relationship.source.node.keys", "from:id")
      .option("relationship.target.save.mode",
        "Append")
      .option("relationship.target.labels", ":City")
      .option("relationship.target.node.keys", "to:id")
      .option("relationship.properties", "airline")
      .option("relationship.source.node.keys.skip.nulls", "true")
      .option("relationship.target.node.keys.skip.nulls", "true")
      .save()
    use(SparkConnectorScalaSuiteIT.driver.session()) { session =>
      val cities = session.run("MATCH (n:City) RETURN count(n) as count")
        .single()
        .get("count")
        .asLong()
      assert(cities == 4)
      val citiesWithId5 = session.run("MATCH (n:City {id: 5}) RETURN count(n) as count")
        .single()
        .get("count")
        .asLong()
      assert(citiesWithId5 == 0)
      val flies = session.run("MATCH ()-[r:FLIES_TO]->() RETURN count(r) as count")
        .single()
        .get("count")
        .asLong()
      assert(flies == 2)
    }
  }

  // relationship.keys.skip.nulls: a null relationship key skips only that row; its end
  // nodes are not created either. Guarded on relationship-key-constraint support.
  @Test
  def `skips relationships when relationship key properties contain null values`(): Unit = {
    Assume.assumeTrue(
      CanIUse.INSTANCE.canIUse(Schema.INSTANCE.relationshipKeyConstraints()).withNeo4j(SparkConnectorScalaSuiteIT.neo4j)
    )
    val cities = Seq(
      (Some(1), Some(2), Some("BA721"), "British Airways"),
      (Some(2), Some(3), Some("TK211"), "Turkish Airlines"),
      (Some(3), Some(5), None, "Another Airline")
    ).toDF("from", "to", "flight", "airline")
    cities.write
      .format(classOf[DataSource].getName)
      .mode(SaveMode.Overwrite)
      .option("url", SparkConnectorScalaSuiteIT.server.getBoltUrl)
      .option("relationship", "FLIES_TO")
      .option("relationship.save.strategy", "keys")
      .option("relationship.source.save.mode", "Overwrite")
      .option("relationship.source.labels", ":City")
      .option("relationship.source.node.keys", "from:id")
      .option("relationship.target.save.mode", "Overwrite")
      .option("relationship.target.labels", ":City")
      .option("relationship.target.node.keys", "to:id")
      .option("relationship.keys", "flight")
      .option("relationship.properties", "airline")
      .option("schema.optimization.node.keys", "KEY")
      .option("schema.optimization.relationship.keys", "KEY")
      .option("relationship.keys.skip.nulls", "true")
      .save()
use(SparkConnectorScalaSuiteIT.driver.session()) { session => val cities = session.run("MATCH (n:City) RETURN count(n) as count") .single() .get("count") .asLong() assert(cities == 3) val citiesWithId5 = session.run("MATCH (n:City {id: 5}) RETURN count(n) as count") .single() .get("count") .asLong() assert(citiesWithId5 == 0) val flies = session.run("MATCH ()-[r:FLIES_TO]->() RETURN count(r) as count") .single() .get("count") .asLong() assert(flies == 2) } } } ================================================ FILE: spark-3/src/test/scala/org/neo4j/spark/DataSourceWriterNeo4jTSE.scala ================================================ /* * Copyright (c) "Neo4j" * Neo4j Sweden AB [https://neo4j.com] * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. 
*/ package org.neo4j.spark import org.apache.commons.lang3.exception.ExceptionUtils import org.apache.spark.scheduler.SparkListener import org.apache.spark.scheduler.SparkListenerStageCompleted import org.apache.spark.scheduler.SparkListenerStageSubmitted import org.apache.spark.sql.DataFrame import org.apache.spark.sql.SaveMode import org.hamcrest.Matchers import org.junit.Assert.assertEquals import org.junit.Assert.assertNotNull import org.junit.Assert.assertTrue import org.junit.Test import org.neo4j.Closeables.use import org.neo4j.driver.Session import org.neo4j.driver.SessionConfig import org.neo4j.driver.Transaction import org.neo4j.driver.TransactionWork import org.neo4j.driver.summary.ResultSummary import org.neo4j.spark.writer.DataWriterMetrics import java.util.concurrent.TimeUnit import java.util.concurrent.atomic.AtomicReference class DataSourceWriterNeo4jTSE extends SparkConnectorScalaBaseTSE { import ss.implicits._ @Test def `should read and write relations with append mode`(): Unit = { val total = 100 val fixtureQuery: String = s"""UNWIND range(1, $total) as id |CREATE (pr:Product {id: id * $total, name: 'Product ' + id}) |CREATE (pe:Person {id: id, fullName: 'Person ' + id}) |CREATE (pe)-[:BOUGHT{when: rand(), quantity: rand() * 1000}]->(pr) |RETURN * """.stripMargin use(SparkConnectorScalaSuiteIT.session("system")) { session => session.run("CREATE OR REPLACE DATABASE db1 WAIT 30 seconds").consume() session.run("CREATE OR REPLACE DATABASE db2 WAIT 30 seconds").consume() } use(SparkConnectorScalaSuiteIT.session("db1")) { session => session.run(fixtureQuery).consume() } use(SparkConnectorScalaSuiteIT.session("db2")) { session => session .writeTransaction( new TransactionWork[Unit] { override def execute(tx: Transaction): Unit = { tx.run("CREATE CONSTRAINT person_id FOR (p:Person) REQUIRE p.id IS UNIQUE") tx.run("CREATE CONSTRAINT product_id FOR (p:Product) REQUIRE p.id IS UNIQUE") } } ) } try { val dfOriginal: DataFrame = 
ss.read.format(classOf[DataSource].getName) .option("url", SparkConnectorScalaSuiteIT.server.getBoltUrl) .option("database", "db1") .option("relationship", "BOUGHT") .option("relationship.nodes.map", "false") .option("relationship.source.labels", ":Person") .option("relationship.target.labels", ":Product") .load() dfOriginal.write .format(classOf[DataSource].getName) .mode(SaveMode.Append) .option("url", SparkConnectorScalaSuiteIT.server.getBoltUrl) .option("database", "db2") .option("relationship", "SOLD") .option("relationship.save.strategy", "NATIVE") .option("relationship.source.labels", ":Person") .option("relationship.source.save.mode", "Append") .option("relationship.target.labels", ":Product") .option("relationship.target.save.mode", "Append") .option("batch.size", "11") .save() // let's write again to prove that 2 relationship are being added dfOriginal.write .format(classOf[DataSource].getName) .mode(SaveMode.Append) .option("url", SparkConnectorScalaSuiteIT.server.getBoltUrl) .option("database", "db2") .option("relationship", "SOLD") .option("relationship.save.strategy", "NATIVE") .option("relationship.source.labels", ":Person") .option("relationship.source.save.mode", "Overwrite") .option("relationship.source.node.keys", "source.id:id") .option("relationship.target.labels", ":Product") .option("relationship.target.save.mode", "Overwrite") .option("relationship.target.node.keys", "target.id:id") .option("batch.size", "11") .save() val dfCopy = ss.read.format(classOf[DataSource].getName) .option("url", SparkConnectorScalaSuiteIT.server.getBoltUrl) .option("database", "db2") .option("relationship", "SOLD") .option("relationship.nodes.map", "false") .option("relationship.source.labels", ":Person") .option("relationship.target.labels", ":Product") .load() val dfOriginalCount = dfOriginal.count() assertEquals(dfOriginalCount * 2, dfCopy.count()) val resSourceOrig = dfOriginal.select("`source.id`").orderBy("`source.id`").collectAsList() val resSourceCopy = 
dfCopy.select("`source.id`").orderBy("`source.id`").collectAsList() val resTargetOrig = dfOriginal.select("`target.id`").orderBy("`target.id`").collectAsList() val resTargetCopy = dfCopy.select("`target.id`").orderBy("`target.id`").collectAsList() for (i <- 0 until 1) { assertEquals( resSourceOrig.get(i).getLong(0), resSourceCopy.get(i).getLong(0) ) assertEquals( resTargetOrig.get(i).getLong(0), resTargetCopy.get(i).getLong(0) ) } assertEquals( 2, dfCopy.where("`source.id` = 1").count() ) } finally { SparkConnectorScalaSuiteIT.driver.session(SessionConfig.forDatabase("db2")) .writeTransaction( new TransactionWork[Unit] { override def execute(tx: Transaction): Unit = { tx.run("DROP CONSTRAINT person_id") tx.run("DROP CONSTRAINT product_id") } } ) } } @Test def `should read and write relations with overwrite mode`(): Unit = { val total = 100 val fixtureQuery: String = s"""UNWIND range(1, $total) as id |CREATE (pr:Product {id: id * $total, name: 'Product ' + id}) |CREATE (pe:Person {id: id, fullName: 'Person ' + id}) |CREATE (pe)-[:BOUGHT{when: rand(), quantity: rand() * 1000}]->(pr) |RETURN * """.stripMargin use(SparkConnectorScalaSuiteIT.session("system")) { session => session.run("CREATE OR REPLACE DATABASE db1 WAIT 30 seconds").consume() session.run("CREATE OR REPLACE DATABASE db2 WAIT 30 seconds").consume() } use(SparkConnectorScalaSuiteIT.session("db1")) { session => session.run(fixtureQuery).consume() } use(SparkConnectorScalaSuiteIT.session("db2")) { session => session .writeTransaction( new TransactionWork[Unit] { override def execute(tx: Transaction): Unit = { tx.run("CREATE CONSTRAINT person_id FOR (p:Person) REQUIRE p.id IS UNIQUE") tx.run("CREATE CONSTRAINT product_id FOR (p:Product) REQUIRE p.id IS UNIQUE") } } ) } try { val dfOriginal: DataFrame = ss.read.format(classOf[DataSource].getName) .option("url", SparkConnectorScalaSuiteIT.server.getBoltUrl) .option("database", "db1") .option("relationship", "BOUGHT") .option("relationship.nodes.map", "false") 
.option("relationship.source.labels", ":Person") .option("relationship.target.labels", ":Product") .load() .orderBy("`source.id`", "`target.id`") dfOriginal.write .format(classOf[DataSource].getName) .mode(SaveMode.Overwrite) .option("url", SparkConnectorScalaSuiteIT.server.getBoltUrl) .option("database", "db2") .option("relationship", "SOLD") .option("relationship.save.strategy", "NATIVE") .option("relationship.source.labels", ":Person") .option("relationship.source.save.mode", "Overwrite") .option("relationship.target.labels", ":Product") .option("relationship.target.save.mode", "Overwrite") .option("batch.size", "11") .save() // let's write the same thing again to prove there will be just one relation dfOriginal.write .format(classOf[DataSource].getName) .mode(SaveMode.Overwrite) .option("url", SparkConnectorScalaSuiteIT.server.getBoltUrl) .option("database", "db2") .option("relationship", "SOLD") .option("relationship.save.strategy", "NATIVE") .option("relationship.source.labels", ":Person") .option("relationship.source.node.keys", "source.id:id") .option("relationship.source.save.mode", "Overwrite") .option("relationship.target.labels", ":Product") .option("relationship.target.node.keys", "target.id:id") .option("relationship.target.save.mode", "Overwrite") .option("batch.size", "11") .save() val dfCopy = ss.read.format(classOf[DataSource].getName) .option("url", SparkConnectorScalaSuiteIT.server.getBoltUrl) .option("database", "db2") .option("relationship", "SOLD") .option("relationship.nodes.map", "false") .option("relationship.source.labels", ":Person") .option("relationship.target.labels", ":Product") .load() .orderBy("`source.id`", "`target.id`") val dfOriginalCount = dfOriginal.count() assertEquals(dfOriginalCount, dfCopy.count()) for (i <- 0 until 1) { assertEquals( dfOriginal.select("`source.id`").collectAsList().get(i).getLong(0), dfCopy.select("`source.id`").collectAsList().get(i).getLong(0) ) assertEquals( 
dfOriginal.select("`target.id`").collectAsList().get(i).getLong(0), dfCopy.select("`target.id`").collectAsList().get(i).getLong(0) ) } assertEquals( 1, dfCopy.where("`source.id` = 1").count() ) } finally { SparkConnectorScalaSuiteIT.driver.session(SessionConfig.forDatabase("db2")) .writeTransaction( new TransactionWork[Unit] { override def execute(tx: Transaction): Unit = { tx.run("DROP CONSTRAINT person_id") tx.run("DROP CONSTRAINT product_id") } } ) } } @Test def `should read and write relations with MATCH and node keys`(): Unit = { val total = 100 val fixtureQuery: String = s"""UNWIND range(1, $total) as id |CREATE (pr:Product {id: id * $total, name: 'Product ' + id}) |CREATE (pe:Person {id: id, fullName: 'Person ' + id}) |CREATE (pe)-[:BOUGHT{when: rand(), quantity: rand() * 1000}]->(pr) |RETURN * """.stripMargin use(SparkConnectorScalaSuiteIT.session("system")) { session => session.run("CREATE OR REPLACE DATABASE db1 WAIT 30 seconds").consume() session.run("CREATE OR REPLACE DATABASE db2 WAIT 30 seconds").consume() } use(SparkConnectorScalaSuiteIT.session("db1")) { session => session.run(fixtureQuery).consume() } use(SparkConnectorScalaSuiteIT.session("db2")) { session => session.run(fixtureQuery).consume() } val dfOriginal: DataFrame = ss.read.format(classOf[DataSource].getName) .option("url", SparkConnectorScalaSuiteIT.server.getBoltUrl) .option("database", "db1") .option("relationship", "BOUGHT") .option("relationship.nodes.map", "false") .option("relationship.source.labels", ":Person") .option("relationship.target.labels", ":Product") .load() .orderBy("`source.id`", "`target.id`") dfOriginal.write .format(classOf[DataSource].getName) .mode(SaveMode.Overwrite) .option("url", SparkConnectorScalaSuiteIT.server.getBoltUrl) .option("database", "db2") .option("relationship", "SOLD") .option("relationship.save.strategy", "NATIVE") .option("relationship.source.labels", ":Person") .option("relationship.target.labels", ":Product") 
.option("relationship.source.node.keys", "source.id:id") .option("relationship.target.node.keys", "target.id:id") .option("relationship.source.save.mode", "Match") .option("relationship.target.save.mode", "Match") .option("batch.size", "11") .save() val dfCopy = ss.read.format(classOf[DataSource].getName) .option("url", SparkConnectorScalaSuiteIT.server.getBoltUrl) .option("database", "db2") .option("relationship", "SOLD") .option("relationship.nodes.map", "false") .option("relationship.source.labels", ":Person") .option("relationship.target.labels", ":Product") .load() .orderBy("`source.id`", "`target.id`") for (i <- 0 until 1) { assertEquals( dfOriginal.select("`source.id`").collectAsList().get(i).getLong(0), dfCopy.select("`source.id`").collectAsList().get(i).getLong(0) ) assertEquals( dfOriginal.select("`target.id`").collectAsList().get(i).getLong(0), dfCopy.select("`target.id`").collectAsList().get(i).getLong(0) ) } } @Test def `should read and write relations with MERGE and node keys`(): Unit = { val total = 100 val fixtureQuery: String = s"""UNWIND range(1, $total) as id |CREATE (pr:Product {id: id * $total, name: 'Product ' + id}) |CREATE (pe:Person {id: id, fullName: 'Person ' + id}) |CREATE (pe)-[:BOUGHT{when: rand(), quantity: rand() * 1000}]->(pr) |RETURN * """.stripMargin use(SparkConnectorScalaSuiteIT.session("system")) { session => session.run("CREATE OR REPLACE DATABASE db1 WAIT 30 seconds").consume() session.run("CREATE OR REPLACE DATABASE db2 WAIT 30 seconds").consume() } use(SparkConnectorScalaSuiteIT.session("db1")) { session => session.run(fixtureQuery).consume() } val dfOriginal: DataFrame = ss.read.format(classOf[DataSource].getName) .option("url", SparkConnectorScalaSuiteIT.server.getBoltUrl) .option("database", "db1") .option("relationship", "BOUGHT") .option("relationship.nodes.map", "false") .option("relationship.source.labels", ":Person") .option("relationship.target.labels", ":Product") .load() .orderBy("`source.id`", "`target.id`") 
dfOriginal.write .format(classOf[DataSource].getName) .mode(SaveMode.Overwrite) .option("url", SparkConnectorScalaSuiteIT.server.getBoltUrl) .option("database", "db2") .option("relationship", "SOLD") .option("relationship.save.strategy", "NATIVE") .option("relationship.source.labels", ":Person") .option("relationship.target.labels", ":Product") .option("relationship.source.node.keys", "source.id:id") .option("relationship.source.save.mode", "Overwrite") .option("relationship.target.node.keys", "target.id:id") .option("relationship.target.save.mode", "Overwrite") .option("batch.size", "11") .save() val dfCopy = ss.read.format(classOf[DataSource].getName) .option("url", SparkConnectorScalaSuiteIT.server.getBoltUrl) .option("database", "db2") .option("relationship", "SOLD") .option("relationship.nodes.map", "false") .option("relationship.source.labels", ":Person") .option("relationship.target.labels", ":Product") .load() .orderBy("`source.id`", "`target.id`") for (i <- 0 until 1) { assertEquals( dfOriginal.select("`source.id`").collectAsList().get(i).getLong(0), dfCopy.select("`source.id`").collectAsList().get(i).getLong(0) ) assertEquals( dfOriginal.select("`target.id`").collectAsList().get(i).getLong(0), dfCopy.select("`target.id`").collectAsList().get(i).getLong(0) ) } } @Test def `should read relations and write relation with match mode`(): Unit = { val fixtureQuery: String = s"""CREATE (m:Musician {name: "John Bonham", age: 32}) |CREATE (i:Instrument {name: "Drums"}) |RETURN * """.stripMargin SparkConnectorScalaSuiteIT.driver.session(SessionConfig.forDatabase("db1")) .writeTransaction( new TransactionWork[ResultSummary] { override def execute(tx: Transaction): ResultSummary = tx.run(fixtureQuery).consume() } ) val musicDf = Seq( (12, 32, "John Bonham", "Drums") ).toDF("experience", "age", "name", "instrument") musicDf.write .format(classOf[DataSource].getName) .mode(SaveMode.Overwrite) .option("url", SparkConnectorScalaSuiteIT.server.getBoltUrl) 
.option("database", "db1") .option("relationship.nodes.map", "false") .option("relationship", "PLAYS") .option("relationship.source.save.mode", "Match") .option("relationship.target.save.mode", "Match") .option("relationship.save.strategy", "keys") .option("relationship.source.labels", ":Musician") .option("relationship.source.node.keys", "name,age") .option("relationship.target.labels", ":Instrument") .option("relationship.target.node.keys", "instrument:name") .save() val df2 = ss.read.format(classOf[DataSource].getName) .option("url", SparkConnectorScalaSuiteIT.server.getBoltUrl) .option("database", "db1") .option("relationship.nodes.map", "false") .option("relationship", "PLAYS") .option("relationship.source.labels", ":Musician") .option("relationship.target.labels", ":Instrument") .load() val experience = df2.select("`source.age`").where("`source.name` = 'John Bonham'") .collectAsList().get(0).getLong(0) assertEquals(32, experience) } @Test def `should give a more clear error if properties or keys are inverted`(): Unit = { val musicDf = Seq( (1, 12, "John Henry Bonham", "Drums"), (2, 19, "John Mayer", "Guitar"), (3, 32, "John Scofield", "Guitar"), (4, 15, "John Butler", "Guitar") ).toDF("id", "experience", "name", "instrument") try { musicDf.write .format(classOf[DataSource].getName) .mode(SaveMode.Overwrite) .option("url", SparkConnectorScalaSuiteIT.server.getBoltUrl) .option("database", "db1") .option("relationship", "PLAYS") .option("relationship.source.save.mode", "Overwrite") .option("relationship.target.save.mode", "Overwrite") .option("relationship.save.strategy", "keys") .option("relationship.source.labels", ":Musician") .option("relationship.source.node.keys", "musician_name:name") .option("relationship.target.labels", ":Instrument") .option("relationship.target.node.keys", "instrument:name") .save() } catch { case exception: IllegalArgumentException => { val clientException = ExceptionUtils.getRootCause(exception) 
assertTrue(clientException.getMessage.equals( """Write failed due to the following errors: | - Schema is missing musician_name from option `relationship.source.node.keys` | |The option key and value might be inverted.""".stripMargin )) } case generic: Throwable => fail(s"should be thrown a ${classOf[IllegalArgumentException].getName}, got ${generic.getClass} instead") } } @Test def `should give a more clear error if properties or keys are inverted on different options`(): Unit = { val musicDf = Seq( (1, 12, "John Henry Bonham", "Drums"), (2, 19, "John Mayer", "Guitar"), (3, 32, "John Scofield", "Guitar"), (4, 15, "John Butler", "Guitar") ).toDF("id", "experience", "name", "instrument") try { musicDf.write .format(classOf[DataSource].getName) .mode(SaveMode.Overwrite) .option("url", SparkConnectorScalaSuiteIT.server.getBoltUrl) .option("database", "db1") .option("relationship", "PLAYS") .option("relationship.source.save.mode", "Overwrite") .option("relationship.target.save.mode", "Overwrite") .option("relationship.save.strategy", "keys") .option("relationship.source.labels", ":Musician") .option("relationship.source.node.keys", "musician_name:name,another_name:name") .option("relationship.target.labels", ":Instrument") .option("relationship.target.node.keys", "instrument_name:name") .save() } catch { case exception: IllegalArgumentException => { val clientException = ExceptionUtils.getRootCause(exception) assertTrue(clientException.getMessage.equals( """Write failed due to the following errors: | - Schema is missing instrument_name from option `relationship.target.node.keys` | - Schema is missing musician_name, another_name from option `relationship.source.node.keys` | |The option key and value might be inverted.""".stripMargin )) } case generic: Throwable => fail(s"should be thrown a ${classOf[IllegalArgumentException].getName}, got ${generic.getClass} instead") } } @Test def `should give a more clear error if node properties or keys are inverted`(): Unit = { val 
musicDf = Seq( (1, 12, "John Henry Bonham", "Drums"), (2, 19, "John Mayer", "Guitar"), (3, 32, "John Scofield", "Guitar"), (4, 15, "John Butler", "Guitar") ).toDF("id", "experience", "name", "instrument") try { musicDf.write .format(classOf[DataSource].getName) .mode(SaveMode.Append) .option("url", SparkConnectorScalaSuiteIT.server.getBoltUrl) .option("database", "db1") .option("labels", "Person") .option("node.properties", "musician_name:name,another_name:name") .save() } catch { case exception: IllegalArgumentException => { val clientException = ExceptionUtils.getRootCause(exception) assertTrue(clientException.getMessage.equals( """Write failed due to the following errors: | - Schema is missing instrument_name from option `node.properties` | |The option key and value might be inverted.""".stripMargin )) } case generic: Throwable => fail( s"should be thrown a ${classOf[IllegalArgumentException].getName}, got ${generic.getClass} instead: ${generic.getMessage}" ) } } @Test def `exports write metrics`(): Unit = { val input = List("Ali", "Andrea", "Eugene", "Florent") val query = "CREATE (:Name {name: event.name})-[:STARTS_WITH]->(:Letter {value: left(event.name, 1)})" val expectedMetrics = Map( DataWriterMetrics.RECORDS_WRITTEN_DESCRIPTION -> 4, DataWriterMetrics.RELATIONSHIPS_CREATED_DESCRIPTION -> 4, DataWriterMetrics.NODES_CREATED_DESCRIPTION -> 8, DataWriterMetrics.PROPERTIES_SET_DESCRIPTION -> 8 ) val metrics = new AtomicReference[Map[String, Any]]() val listener = new MetricsListener(expectedMetrics.keySet, metrics) ss.sparkContext.addSparkListener(listener) try { input.toDF("name") .repartition(1) .write .format(classOf[DataSource].getName) .mode(SaveMode.Append) .option("url", SparkConnectorScalaSuiteIT.server.getBoltUrl) .option("database", "db1") .option("batch.size", 1) // TODO: remove this when https://issues.apache.org/jira/browse/SPARK-45759 is fixed .option("query", query) .save() val db1Session = 
SparkConnectorScalaSuiteIT.driver.session(SessionConfig.forDatabase("db1")) Assert.assertEventually( () => { db1Session.run( "MATCH (:Name)-[r:STARTS_WITH]->(:Letter) RETURN count(r) as cnt" ).single().get("cnt").asLong() }, Matchers.equalTo(4L), 30L, TimeUnit.SECONDS ) assertNotNull(metrics.get()) assertEquals(metrics.get(), expectedMetrics) } finally { ss.sparkContext.removeSparkListener(listener) } } @Test def `does not create constraint if schema validation fails`(): Unit = { val cities = Seq( (1, "Cherbourg en Cotentin"), (2, "London"), (3, "Malmö") ).toDF("id", "city") try { cities.write .format(classOf[DataSource].getName) .mode(SaveMode.Overwrite) .option("url", SparkConnectorScalaSuiteIT.server.getBoltUrl) .option("labels", ":News") .option("node.keys", "newsId") .option("schema.optimization.node.keys", "UNIQUE") .save() } catch { case _: Exception => {} } var session: Session = null try { session = SparkConnectorScalaSuiteIT.driver.session() val result = session.run("SHOW CONSTRAINTS YIELD labelsOrTypes WHERE labelsOrTypes[0] = 'News' RETURN count(*) AS count") .single() .get("count") .asLong() assertEquals(0, result) } finally { if (session != null) { session.close() } } } class MetricsListener(names: Set[String], captureMetrics: AtomicReference[Map[String, Any]]) extends SparkListener { override def onStageSubmitted(stageSubmitted: SparkListenerStageSubmitted): Unit = { stageSubmitted.stageInfo.accumulables.clear() } override def onStageCompleted(stageCompleted: SparkListenerStageCompleted): Unit = { captureMetrics.set(stageCompleted .stageInfo .accumulables .values .filter(metric => metric.name.exists(names.contains)) .map(metric => (metric.name.get, metric.value.get)) .toMap) } } } ================================================ FILE: spark-3/src/test/scala/org/neo4j/spark/DataSourceWriterTSE.scala ================================================ /* * Copyright (c) "Neo4j" * Neo4j Sweden AB [https://neo4j.com] * * Licensed under the Apache License, 
Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ package org.neo4j.spark import junitparams.JUnitParamsRunner import junitparams.Parameters import org.apache.commons.lang3.exception.ExceptionUtils import org.apache.spark.SparkException import org.apache.spark.sql.DataFrame import org.apache.spark.sql.SaveMode import org.apache.spark.sql.SparkSession import org.apache.spark.sql.types.ArrayType import org.apache.spark.sql.types.ByteType import org.apache.spark.sql.types.DataType import org.apache.spark.sql.types.DataTypes import org.apache.spark.sql.types.DayTimeIntervalType import org.apache.spark.sql.types.DecimalType import org.apache.spark.sql.types.YearMonthIntervalType import org.junit import org.junit.Assert._ import org.junit.Ignore import org.junit.Test import org.junit.runner.RunWith import org.neo4j.driver.Result import org.neo4j.driver.Transaction import org.neo4j.driver.TransactionWork import org.neo4j.driver.Value import org.neo4j.driver.exceptions.ClientException import org.neo4j.driver.exceptions.value.Uncoercible import org.neo4j.driver.internal.InternalPoint2D import org.neo4j.driver.internal.InternalPoint3D import org.neo4j.driver.internal.types.InternalTypeSystem import org.neo4j.driver.summary.ResultSummary import org.neo4j.driver.types.IsoDuration import org.neo4j.driver.types.Type import org.neo4j.spark.RowUtil.getByName import org.neo4j.spark.util.Neo4jOptions import org.scalatest.matchers.must.Matchers.be import org.scalatest.matchers.must.Matchers.include import 
org.scalatest.matchers.must.Matchers.the import org.scalatest.matchers.should.Matchers.convertToAnyShouldWrapper import java.time.LocalTime import java.time.OffsetTime import java.time.ZoneId import java.time.ZoneOffset import java.util.TimeZone import scala.collection.JavaConverters._ import scala.collection.immutable.ListMap import scala.language.postfixOps import scala.math.Ordering.Implicits.infixOrderingOps import scala.util.Random abstract class Neo4jType(`type`: String) case class Duration(months: Long, days: Long, seconds: Long, nanoseconds: Long, `type`: String = "duration") extends Neo4jType(`type`) case class Point2d(`type`: String = "point-2d", srid: Int, x: Double, y: Double) extends Neo4jType(`type`) case class Point3d(`type`: String = "point-3d", srid: Int, x: Double, y: Double, z: Double) extends Neo4jType(`type`) case class Time(`type`: String = "offset-time", value: String) extends Neo4jType(`type`) case class LocalTimeValue(`type`: String = "local-time", value: String) extends Neo4jType(`type`) case class Person(name: String, surname: String, age: Int, livesIn: Point3d) case class Person_TimeAndLocalTime(name: String, time: Time, localTime: LocalTimeValue) case class SimplePerson(name: String, surname: String) case class EmptyRow[T](data: T) @RunWith(classOf[JUnitParamsRunner]) class DataSourceWriterTSE extends SparkConnectorScalaBaseTSE { val sparkSession = SparkSession.builder() .master("local[*]") .appName("DataSourceWriterTSE") .getOrCreate() import sparkSession.implicits._ private def testType[T](ds: DataFrame, neo4jType: Type): Unit = { ds.write .format(classOf[DataSource].getName) .mode(SaveMode.Append) .option("url", SparkConnectorScalaSuiteIT.server.getBoltUrl) .option("labels", ":MyNode:MyLabel") .save() val records = SparkConnectorScalaSuiteIT.session().run( """MATCH (p:MyNode:MyLabel) |RETURN p.foo AS foo |""".stripMargin ).list().asScala .filter(r => r.get("foo").hasType(neo4jType)) .map(r => r.asMap().asScala) .toSet val expected = 
ds.collect() .map(row => Map("foo" -> { val foo = row.getAs[T]("foo") foo match { case sqlDate: java.sql.Date => sqlDate.toLocalDate case sqlTimestamp: java.sql.Timestamp => sqlTimestamp.toInstant.atZone(ZoneOffset.UTC) case _ => foo } }) ) .toSet assertEquals(expected, records) } private def testArray[T](ds: DataFrame): Unit = { ds.write .format(classOf[DataSource].getName) .mode(SaveMode.Append) .option("url", SparkConnectorScalaSuiteIT.server.getBoltUrl) .option("labels", ":MyNode:MyLabel") .save() val records = SparkConnectorScalaSuiteIT.session().run( """MATCH (p:MyNode:MyLabel) |RETURN p.foo AS foo |""".stripMargin ).list().asScala .filter(r => r.get("foo").hasType(InternalTypeSystem.TYPE_SYSTEM.LIST())) .map(r => r.get("foo").asList()) .toSet val expected = ds.collect() .map(row => row.getList[T](0)) .toSet assertEquals(expected, records) } private def testDurationType(ds: DataFrame, expected: Set[Duration]): Unit = { ds.write .format(classOf[DataSource].getName) .mode(SaveMode.Append) .option("url", SparkConnectorScalaSuiteIT.server.getBoltUrl) .option("labels", ":Duration") .save() val records = SparkConnectorScalaSuiteIT.session().run( """MATCH (d:Duration) |RETURN d.duration AS duration |""".stripMargin ).list().asScala .filter(r => r.get("duration").hasType(InternalTypeSystem.TYPE_SYSTEM.DURATION())) .map(r => r.get("duration").asIsoDuration()) .map(data => Duration(data.months, data.days, data.seconds, data.nanoseconds)) .toSet assertEquals(expected, records) } @Test def testThrowsExceptionIfNoValidReadOptionIsSet(): Unit = { try { ss.read.format(classOf[DataSource].getName) .option("url", SparkConnectorScalaSuiteIT.server.getBoltUrl) .load() .show() // we need the action to be able to trigger the exception because of the changes in Spark 3 } catch { case e: IllegalArgumentException => assertEquals("No valid option found. 
One of `GDS`, `LABELS`, `QUERY`, `RELATIONSHIP` is required", e.getMessage) case _: Throwable => fail(s"should be thrown a ${classOf[IllegalArgumentException].getName}") } } @Test def testThrowsExceptionIfTwoValidReadOptionAreSet(): Unit = { try { ss.read.format(classOf[DataSource].getName) .option("url", SparkConnectorScalaSuiteIT.server.getBoltUrl) .option("labels", "Person") .option("relationship", "KNOWS") .load() // we need the action to be able to trigger the exception because of the changes in Spark 3 } catch { case e: IllegalArgumentException => assertEquals( "You need to specify just one of these options: 'gds', 'labels', 'query', 'relationship'", e.getMessage ) case _: Throwable => fail(s"should be thrown a ${classOf[IllegalArgumentException].getName}") } } @Test def testThrowsExceptionIfThreeValidReadOptionAreSet(): Unit = { try { ss.read.format(classOf[DataSource].getName) .option("url", SparkConnectorScalaSuiteIT.server.getBoltUrl) .option("labels", "Person") .option("relationship", "KNOWS") .option("query", "MATCH (n) RETURN n") .load() // we need the action to be able to trigger the exception because of the changes in Spark 3 } catch { case e: IllegalArgumentException => assertEquals( "You need to specify just one of these options: 'gds', 'labels', 'query', 'relationship'", e.getMessage ) case _: Throwable => fail(s"should be thrown a ${classOf[IllegalArgumentException].getName}") } } @Test def `should write nodes with string values into Neo4j`(): Unit = { val ds = (1 to 10) .map(i => i.toString) .toDF("foo") testType[String](ds, InternalTypeSystem.TYPE_SYSTEM.STRING()) } @Test def `should write nodes with string array values into Neo4j`(): Unit = { val ds = (1 to 10) .map(i => i.toString) .map(i => Array(i, i)) .toDF("foo") testArray[String](ds) } @Test def `should write nodes with int values into Neo4j`(): Unit = { val ds = (1 to 10) .map(i => i) .toDF("foo") testType[Int](ds, InternalTypeSystem.TYPE_SYSTEM.INTEGER()) } @Test def `should write 
nodes with byte values into Neo4j`(): Unit = { val ds = (1 to 10) .map(_.toByte) .toDF("foo") testType[Byte](ds, InternalTypeSystem.TYPE_SYSTEM.INTEGER()) } @Test def `should write nodes with short values into Neo4j`(): Unit = { val ds = (1 to 10) .map(_.toShort) .toDF("foo") testType[Short](ds, InternalTypeSystem.TYPE_SYSTEM.INTEGER()) } @Test def `should write nodes with date values into Neo4j`(): Unit = { val ds = (1 to 5) .map(i => java.sql.Date.valueOf("2020-01-0" + i)) .toDF("foo") testType[java.sql.Date](ds, InternalTypeSystem.TYPE_SYSTEM.DATE()) } @Test def `should write nodes with timestamp values into Neo4j`(): Unit = { val ds = (1 to 5) .map(i => java.sql.Timestamp.valueOf(s"2020-01-0$i 11:11:11.11")) .toDF("foo") testType[java.sql.Timestamp](ds, InternalTypeSystem.TYPE_SYSTEM.DATE_TIME()) } @Test def `should write nodes with timestampNTZ values into Neo4j`(): Unit = { val ds = (1 to 5) .map(i => java.time.LocalDateTime.of(2020, 1, i, 11, 11, 11, 111000000)) .toDF("foo") testType[java.time.LocalDateTime](ds, InternalTypeSystem.TYPE_SYSTEM.LOCAL_DATE_TIME()) } @Test def `should write nodes with int array values into Neo4j`(): Unit = { val ds = (1 to 10) .map(i => i.toLong) .map(i => Array(i, i)) .toDF("foo") testArray[Long](ds) } @Test def `should write nodes with point-2d values into Neo4j`(): Unit = { val ds = (1 to 10) .map(i => EmptyRow(Point2d(srid = 4326, x = Random.nextDouble(), y = Random.nextDouble()))) .toDS() ds.write .format(classOf[DataSource].getName) .mode(SaveMode.Append) .option("url", SparkConnectorScalaSuiteIT.server.getBoltUrl) .option("labels", ":MyNode:MyLabel") .save() val records = SparkConnectorScalaSuiteIT.session().run( """MATCH (p:MyNode:MyLabel) |RETURN p.data AS data |""".stripMargin ).list().asScala .filter(r => r.get("data").hasType(InternalTypeSystem.TYPE_SYSTEM.POINT())) .map(r => { val point = r.get("data").asPoint() (point.srid(), point.x(), point.y()) }) .toSet val expected = ds.collect() .map(point => 
// Arrays of 2D points -> Neo4j LIST of points; compared as sets of (srid, x, y) sequences.
@Test
def `should write nodes with point-2d array values into Neo4j`(): Unit = {
  val ds = (1 to 10)
    .map(i =>
      EmptyRow(Seq(
        Point2d(srid = 4326, x = Random.nextDouble(), y = Random.nextDouble()),
        Point2d(srid = 4326, x = Random.nextDouble(), y = Random.nextDouble())
      ))
    )
    .toDS()
  ds.write
    .format(classOf[DataSource].getName)
    .mode(SaveMode.Append)
    .option("url", SparkConnectorScalaSuiteIT.server.getBoltUrl)
    .option("labels", ":MyNode:MyLabel")
    .save()
  val records = SparkConnectorScalaSuiteIT.session().run(
    """MATCH (p:MyNode:MyLabel)
      |RETURN p.data AS data
      |""".stripMargin
  ).list().asScala
    .filter(r => r.get("data").hasType(InternalTypeSystem.TYPE_SYSTEM.LIST()))
    .map(r =>
      r.get("data")
        .asList.asScala
        .map(_.asInstanceOf[InternalPoint2D])
        .map(point => (point.srid(), point.x(), point.y()))
    )
    .toSet
  val expected = ds.collect()
    .map(row => row.data.map(p => (p.srid, p.x, p.y)))
    .toSet
  assertEquals(expected, records)
}

// 3D spatial points (srid 4979); note the read-back only compares (srid, x, y).
@Test
def `should write nodes with point-3d values into Neo4j`(): Unit = {
  val ds = (1 to 10)
    .map(i =>
      EmptyRow(Point3d(srid = 4979, x = Random.nextDouble(), y = Random.nextDouble(), z = Random.nextDouble()))
    )
    .toDS()
  ds.write
    .format(classOf[DataSource].getName)
    .mode(SaveMode.Append)
    .option("url", SparkConnectorScalaSuiteIT.server.getBoltUrl)
    .option("labels", ":MyNode:MyLabel")
    .save()
  val records = SparkConnectorScalaSuiteIT.session().run(
    """MATCH (p:MyNode:MyLabel)
      |RETURN p.data AS data
      |""".stripMargin
  ).list().asScala
    .filter(r => r.get("data").hasType(InternalTypeSystem.TYPE_SYSTEM.POINT()))
    .map(r => {
      val point = r.get("data").asPoint()
      (point.srid(), point.x(), point.y())
    })
    .toSet
  val expected = ds.collect()
    .map(point => (point.data.srid, point.data.x, point.data.y))
    .toSet
  assertEquals(expected, records)
}

// Arrays of 3D points -> Neo4j LIST; here all four coordinates (srid, x, y, z) are compared.
@Test
def `should write nodes with point-3d array values into Neo4j`(): Unit = {
  val ds = (1 to 10)
    .map(i =>
      EmptyRow(Seq(
        Point3d(srid = 4979, x = Random.nextDouble(), y = Random.nextDouble(), z = Random.nextDouble()),
        Point3d(srid = 4979, x = Random.nextDouble(), y = Random.nextDouble(), z = Random.nextDouble())
      ))
    )
    .toDS()
  ds.write
    .format(classOf[DataSource].getName)
    .mode(SaveMode.Append)
    .option("url", SparkConnectorScalaSuiteIT.server.getBoltUrl)
    .option("labels", ":MyNode:MyLabel")
    .save()
  val records = SparkConnectorScalaSuiteIT.session().run(
    """MATCH (p:MyNode:MyLabel)
      |RETURN p.data AS data
      |""".stripMargin
  ).list().asScala
    .filter(r => r.get("data").hasType(InternalTypeSystem.TYPE_SYSTEM.LIST()))
    .map(r =>
      r.get("data")
        .asList.asScala
        .map(_.asInstanceOf[InternalPoint3D])
        .map(point => (point.srid(), point.x(), point.y(), point.z()))
    )
    .toSet
  val expected = ds.collect()
    .map(row => row.data.map(p => (p.srid, p.x, p.y, p.z)))
    .toSet
  assertEquals(expected, records)
}
// Map columns are flattened into dot-prefixed node properties ("foo.fieldN"),
// so the whole node is read back and compared against the flattened expectation.
@Test
def `should write nodes with map values into Neo4j`(): Unit = {
  val ds = (1 to 10)
    .map(i => Map("field" + i -> i))
    .toDF("foo")
  ds.write
    .format(classOf[DataSource].getName)
    .mode(SaveMode.Append)
    .option("url", SparkConnectorScalaSuiteIT.server.getBoltUrl)
    .option("labels", ":MyNode:MyLabel")
    .save()
  val records = SparkConnectorScalaSuiteIT.session().run(
    """MATCH (p:MyNode:MyLabel)
      |RETURN p
      |""".stripMargin
  ).list().asScala
    .filter(r => r.get("p").hasType(InternalTypeSystem.TYPE_SYSTEM.MAP()))
    .map(r => r.get("p").asMap().asScala)
    .toSet
  val expected = ds.collect().map(row => row.getMap[String, AnyRef](0))
    // column name is prefixed onto each map key when flattened
    .map(map => map.map(t => (s"foo.${t._1}", t._2)).toMap)
    .toSet
  assertEquals(expected, records)
}

// java.time.Period (months only) -> Neo4j duration with only the months component set.
@Test
def `should write nodes with duration values into Neo4j from java period`(): Unit = {
  val range = 1 to 10
  val ds = range
    .map(i => java.time.Period.ofMonths(i))
    .toDF("duration")
  val expected = range
    .map(i => Duration(i, 0, 0, 0))
    .toSet
  testDurationType(ds, expected)
}

// java.time.Duration (days only) -> Neo4j duration with only the days component set.
@Test
def `should write nodes with duration values into Neo4j from java duration`(): Unit = {
  val range = 1 to 10
  val ds = range
    .map(i => java.time.Duration.ofDays(i.toLong))
    .toDF("duration")
  val expected = range
    .map(i => Duration(0, i, 0, 0))
    .toSet
  testDurationType(ds, expected)
}
// Struct-shaped Duration(months, days, seconds, nanoseconds) -> native Neo4j duration.
@Test
def `should write nodes with duration values into Neo4j from struct`(): Unit = {
  val range = 1 to 10
  val ds = range
    .map(i => i.toLong)
    .map(i => EmptyRow(Duration(i, i, i, i)))
    .toDF("duration")
  val expected = range
    .map(i => Duration(i, i, i, i))
    .toSet
  testDurationType(ds, expected)
}

// Arrays of struct-shaped durations -> Neo4j LIST of durations; compared component-wise.
@Test
def `should write nodes with duration array values into Neo4j from struct`(): Unit = {
  val ds = (1 to 10)
    .map(i => i.toLong)
    .map(i =>
      EmptyRow(Seq(
        Duration(i, i, i, i),
        Duration(i, i, i, i)
      ))
    )
    .toDS()
  ds.write
    .format(classOf[DataSource].getName)
    .mode(SaveMode.Append)
    .option("url", SparkConnectorScalaSuiteIT.server.getBoltUrl)
    .option("labels", "BeanWithDuration")
    .save()
  val records = SparkConnectorScalaSuiteIT.session().run(
    """MATCH (p:BeanWithDuration)
      |RETURN p.data AS data
      |""".stripMargin
  ).list().asScala
    .map(r =>
      r.get("data")
        .asList.asScala
        .map(_.asInstanceOf[IsoDuration])
        .map(data => (data.months, data.days, data.seconds, data.nanoseconds))
    )
    .toSet
  val expected = ds.collect()
    .map(row => row.data.map(data => (data.months, data.days, data.seconds, data.nanoseconds)))
    .toSet
  assertEquals(expected, records)
}

// One parameterized case for the SQL-interval tests below: an interval literal (or
// timestamp arithmetic expression), the Neo4j duration it should produce, and the
// Spark interval type the schema is expected to infer.
private case class DurationCase(
  intervalExpression: String,
  expectedDuration: Duration,
  expectedDt: Class[_ <: DataType] = classOf[DayTimeIntervalType]
) {
  // timestamp subtraction is already a full SQL expression; literals need the INTERVAL prefix
  private val isArithmetic = intervalExpression.startsWith("timestamp")
  val sql: String = if (isArithmetic) {
    intervalExpression
  } else {
    s"INTERVAL $intervalExpression"
  }
}

// Cases for `interval SQL literals map to native neo4j durations` (JUnitParams source).
private def sqlDurationCases: java.util.List[DurationCase] = java.util.Arrays.asList(
  // DAY/TIME -> DayTimeIntervalType
  DurationCase("'3' DAY", Duration(0, 3, 0, 0)),
  DurationCase("'10 05' DAY TO HOUR", Duration(0, 10, 5L * 3600, 0)),
  DurationCase("'10 05:30' DAY TO MINUTE", Duration(0, 10, 5L * 3600 + 30L * 60, 0)),
  DurationCase("'10 05:30:15.123456' DAY TO SECOND", Duration(0, 10, 5L * 3600 + 30L * 60 + 15L, 123456000)),
  DurationCase("'12' HOUR", Duration(0, 0, 12L * 3600, 0)),
  DurationCase("'12:34' HOUR TO MINUTE", Duration(0, 0, 12L * 3600 + 34L * 60, 0)),
  DurationCase("'12:34:56.123456' HOUR TO SECOND", Duration(0, 0, 12L * 3600 + 34L * 60 + 56L, 123456000)),
  DurationCase("'42' MINUTE", Duration(0, 0, 42L * 60, 0)),
  DurationCase("'42:07.001002' MINUTE TO SECOND", Duration(0, 0, 42L * 60 + 7L, 1002000)),
  DurationCase("'59.000001' SECOND", Duration(0, 0, 59L, 1000)),
  DurationCase(
    "timestamp('2025-01-02 18:30:00.454') - timestamp('2024-01-01 00:00:00')",
    Duration(0, 367, 66600L, 454000000)
  ),
  // YEAR/MONTH -> YearMonthIntervalType
  DurationCase("'3' YEAR", Duration(36, 0, 0, 0), classOf[YearMonthIntervalType]),
  DurationCase("'7' MONTH", Duration(7, 0, 0, 0), classOf[YearMonthIntervalType]),
  DurationCase("'4-5' YEAR TO MONTH", Duration(53, 0, 0, 0), classOf[YearMonthIntervalType])
)
// For each SQL interval case: checks Spark inferred the expected interval type, writes the
// value, and verifies the stored Neo4j duration component-by-component. A random UUID id
// isolates each parameterized run's node from the others.
@Test
@Parameters(method = "sqlDurationCases")
def `interval SQL literals map to native neo4j durations`(testCase: DurationCase): Unit = {
  val id = java.util.UUID.randomUUID().toString
  val df = sparkSession.sql(s"SELECT '$id' AS id, ${testCase.sql} AS duration")
  df.write
    .format(classOf[DataSource].getName)
    .mode(SaveMode.Append)
    .option("url", SparkConnectorScalaSuiteIT.server.getBoltUrl)
    .option("labels", "Dur")
    .save()
  val wantType = testCase.expectedDt.getSimpleName
  val gotType = df.schema("duration").dataType
  assertTrue(s"expected Spark to pick $wantType but it was $gotType", testCase.expectedDt.isInstance(gotType))
  val gotDuration = SparkConnectorScalaSuiteIT.session().run(
    s"""MATCH (d:Dur {id: '$id'})
       |RETURN d.duration AS duration
       |""".stripMargin
  ).single().get("duration").asIsoDuration()
  assertEquals(s"${testCase.sql} -> months", testCase.expectedDuration.months, gotDuration.months)
  assertEquals(s"${testCase.sql} -> days", testCase.expectedDuration.days, gotDuration.days)
  assertEquals(s"${testCase.sql} -> seconds", testCase.expectedDuration.seconds, gotDuration.seconds)
  assertEquals(s"${testCase.sql} -> nanos", testCase.expectedDuration.nanoseconds, gotDuration.nanoseconds)
}
// Cases for the interval-array test below; expectedDuration is unused (null) because the
// test only checks coercibility to IsoDuration, not exact values (one case uses
// current_timestamp(), which is non-deterministic).
private val sqlDurationArrayCases: java.util.List[Seq[DurationCase]] = java.util.Arrays.asList(
  Seq(
    DurationCase("'10 05:30:15.123' DAY TO SECOND", null),
    DurationCase("'0 00:00:01.000' DAY TO SECOND", null)
  ),
  Seq(
    DurationCase("timestamp('2024-01-02 00:00:00') - timestamp('2024-01-01 00:00:00')", null),
    DurationCase("timestamp('2024-01-01 00:00:00') - current_timestamp()", null)
  ),
  Seq(
    DurationCase("'1-02' YEAR TO MONTH", null, classOf[YearMonthIntervalType]),
    DurationCase("'0-11' YEAR TO MONTH", null, classOf[YearMonthIntervalType])
  )
)

// SQL arrays of intervals: schema must infer ArrayType(<interval type>) and the stored
// property must be coercible to a list of IsoDuration.
@Test
@Parameters(method = "sqlDurationArrayCases")
def `interval SQL arrays map to native neo4j durations arrays`(testCase: Seq[DurationCase]): Unit = {
  val id = java.util.UUID.randomUUID().toString
  val expectedDt = testCase.head.expectedDt
  val sqlArray = testCase.map(_.sql).mkString("array(", ", ", ")")
  val df = sparkSession.sql(s"SELECT '$id' AS id, $sqlArray AS durations")
  df.write
    .format(classOf[DataSource].getName)
    .mode(SaveMode.Append)
    .option("url", SparkConnectorScalaSuiteIT.server.getBoltUrl)
    .option("labels", "DurArr")
    .save()
  val gotType = df.schema("durations").dataType
  assertTrue(
    s"expected Spark to infer ArrayType(${expectedDt.getSimpleName}) but it was $gotType",
    gotType match {
      case ArrayType(et, _) if expectedDt.isInstance(et) => true
      case _ => false
    }
  )
  val result = SparkConnectorScalaSuiteIT.session().run(
    s"""MATCH (d:DurArr {id: '$id'})
       |RETURN d.durations AS durations
       |""".stripMargin
  ).single().get("durations")
  assertTrue(
    s"expected successful conversion to IsoDuration array, but it failed: $result",
    try {
      // only coercibility matters; values may depend on current_timestamp()
      val _ = result.asList((v: Value) => v.asIsoDuration())
      true
    } catch {
      case _: Uncoercible => false
      case e => throw e
    }
  )
}

// TINYINT column keeps ByteType in the Spark schema and lands as a Neo4j integer.
@Test
def `should write TINYINT as neo4j integer`(): Unit = {
  val id = java.util.UUID.randomUUID().toString
  val df = sparkSession.sql(s"SELECT '$id' AS id, CAST(5 AS TINYINT) AS byte")
  df.write
    .format(classOf[DataSource].getName)
    .mode(SaveMode.Append)
    .option("url", SparkConnectorScalaSuiteIT.server.getBoltUrl)
    .option("labels", "Byte")
    .save()
  val wantType = DataTypes.ByteType
  val gotType = df.schema("byte").dataType
  assertTrue(s"expected Spark to pick ${wantType.simpleString} but it was $gotType", wantType == gotType)
  val gotByte = SparkConnectorScalaSuiteIT.session().run(
    s"""MATCH (b:Byte {id: '$id'})
       |RETURN b.byte AS byte
       |""".stripMargin
  ).single().get("byte").asInt()
  assertEquals(5, gotByte)
}
// array<tinyint> columns must be stored as Neo4j's native ByteArray, preserving order/values.
@Test
def `should write BINARY (byte array) as neo4j ByteArray`(): Unit = {
  val id = java.util.UUID.randomUUID().toString
  val sqlArray = (1 to 10).map(i => s"CAST($i AS TINYINT)").mkString("array(", ", ", ")")
  val df = sparkSession.sql(s"SELECT '$id' AS id, $sqlArray AS binary")
  df.write
    .format(classOf[DataSource].getName)
    .mode(SaveMode.Append)
    .option("url", SparkConnectorScalaSuiteIT.server.getBoltUrl)
    .option("labels", "Binary")
    .save()
  val wantType: DataType = DataTypes.ByteType
  val gotType = df.schema("binary").dataType
  assertTrue(
    s"expected Spark to infer ArrayType(${wantType.simpleString}) but it was $gotType",
    gotType match {
      case ArrayType(_: ByteType, _) => true
      case _ => false
    }
  )
  val gotByteArray = SparkConnectorScalaSuiteIT.session().run(
    s"""MATCH (b:Binary {id: '$id'})
       |RETURN b.binary AS binary
       |""".stripMargin
  ).single().get("binary").asByteArray()
  assertEquals(10, gotByteArray.length)
  // element at index b was written as (b + 1)
  for (b <- gotByteArray.indices) {
    val expectedValue = (b + 1).toByte
    assertEquals(expectedValue, gotByteArray(b))
  }
}

// SMALLINT column keeps ShortType in the Spark schema and lands as a Neo4j integer.
@Test
def `should write SMALLINT as neo4j integer`(): Unit = {
  val id = java.util.UUID.randomUUID().toString
  val df = sparkSession.sql(s"SELECT '$id' AS id, CAST(5 AS SMALLINT) AS short")
  df.write
    .format(classOf[DataSource].getName)
    .mode(SaveMode.Append)
    .option("url", SparkConnectorScalaSuiteIT.server.getBoltUrl)
    .option("labels", "Short")
    .save()
  val wantType = DataTypes.ShortType
  val gotType = df.schema("short").dataType
  assertTrue(s"expected Spark to pick ${wantType.simpleString} but it was $gotType", wantType == gotType)
  val gotByte = SparkConnectorScalaSuiteIT.session().run(
    s"""MATCH (b:Short {id: '$id'})
       |RETURN b.short AS short
       |""".stripMargin
  ).single().get("short").asInt()
  assertEquals(5, gotByte)
}
// DECIMAL has no native Neo4j counterpart, so the connector stores it as a string.
@Test
def `should write DECIMAL as neo4j string`(): Unit = {
  val id = java.util.UUID.randomUUID().toString
  val df = sparkSession.sql(s"SELECT '$id' AS id, CAST(5.42 AS DECIMAL(10, 2)) AS decimal")
  df.write
    .format(classOf[DataSource].getName)
    .mode(SaveMode.Append)
    .option("url", SparkConnectorScalaSuiteIT.server.getBoltUrl)
    .option("labels", "Decimal")
    .save()
  val wantType = DecimalType(10, 2)
  val gotType = df.schema("decimal").dataType
  assertTrue(s"expected Spark to pick $wantType but it was $gotType", wantType.typeName == gotType.simpleString)
  val gotDecimal = SparkConnectorScalaSuiteIT.session().run(
    s"""MATCH (b:Decimal {id: '$id'})
       |RETURN b.decimal AS decimal
       |""".stripMargin
  ).single().get("decimal").asString
  assertEquals("5.42", gotDecimal)
}

// Case-class rows with a nested Point3d are written under multiple labels; note the label
// string ":Person: Customer" deliberately contains a space around the separator.
@Test
def `should write nodes into Neo4j with points`(): Unit = {
  val total = 10
  val rand = Random
  val ds = (1 to total)
    .map(i =>
      Person(
        name = "Andrea " + i,
        "Santurbano " + i,
        rand.nextInt(100),
        Point3d(srid = 4979, x = 12.5811776, y = 41.9579492, z = 1.3)
      )
    ).toDS()
  ds.write
    .format(classOf[DataSource].getName)
    .mode(SaveMode.Append)
    .option("url", SparkConnectorScalaSuiteIT.server.getBoltUrl)
    .option("labels", ":Person: Customer")
    .save()
  val count = SparkConnectorScalaSuiteIT.session().run(
    """MATCH (p:Person:Customer)
      |WHERE p.name STARTS WITH 'Andrea'
      |AND p.surname STARTS WITH 'Santurbano'
      |RETURN count(p) AS count
      |""".stripMargin
  ).single().get("count").asInt()
  assertEquals(total, count)
  val records = SparkConnectorScalaSuiteIT.session().run(
    """MATCH (p:Person:Customer)
      |WHERE p.name STARTS WITH 'Andrea'
      |AND p.surname STARTS WITH 'Santurbano'
      |RETURN p.name AS name, p.surname AS surname, p.age AS age,
      | p.bornIn AS bornIn, p.livesIn AS livesIn
      |""".stripMargin
  ).list().asScala
    // every row must carry the expected driver-side types for each projected property
    .filter(r => {
      val map: java.util.Map[String, Object] = r.asMap()
      (map.get("name").isInstanceOf[String]
        && map.get("surname").isInstanceOf[String]
        && map.get("livesIn").isInstanceOf[InternalPoint3D]
        && map.get("age").isInstanceOf[Long])
    })
  assertEquals(total, records.size)
}
// Time (zoned) and LocalTime wrapper values round-trip as driver OffsetTime / LocalTime.
@Test
def `should write nodes into Neo4j with Time and LocalTime Types`(): Unit = {
  val total = 1
  val rand = Random
  val ds = (1 to total)
    .map(i =>
      Person_TimeAndLocalTime(
        name = "Andrea",
        time = Time(value = "12:50:35.556000000+01:00"),
        localTime = LocalTimeValue(value = "12:50:35.556000000")
      )
    ).toDS()
  ds.write
    .format(classOf[DataSource].getName)
    .mode(SaveMode.Overwrite)
    .option("node.keys", "name")
    .option("url", SparkConnectorScalaSuiteIT.server.getBoltUrl)
    .option("labels", ":Person_TimeAndLocalTime")
    .save()
  val count = SparkConnectorScalaSuiteIT.session().run(
    """MATCH (p:Person_TimeAndLocalTime)
      |WHERE p.name STARTS WITH 'Andrea'
      |RETURN count(p) AS count
      |""".stripMargin
  ).single().get("count").asInt()
  assertEquals(total, count)
  val records = SparkConnectorScalaSuiteIT.session().run(
    """MATCH (p:Person_TimeAndLocalTime)
      |WHERE p.name STARTS WITH 'Andrea'
      |RETURN p.name AS name, p.time AS time, p.localTime AS localTime
      |""".stripMargin
  ).list().asScala
    .filter(r => {
      val map: java.util.Map[String, Object] = r.asMap()
      (map.get("name").isInstanceOf[String]
        && map.get("time").isInstanceOf[OffsetTime]
        && map.get("localTime").isInstanceOf[LocalTime])
    })
  assertEquals(total, records.size)
}

// With a uniqueness constraint in place, appending a duplicate node must surface a
// ConstraintValidationFailed ClientException (wrapped in a SparkException); the constraint
// is dropped in the finally block so other tests aren't affected.
@Test
def `should throw an error because the node already exists`(): Unit = {
  SparkConnectorScalaSuiteIT.session()
    .writeTransaction(new TransactionWork[Result] {
      override def execute(transaction: Transaction): Result =
        transaction.run("CREATE CONSTRAINT person_surname FOR (p:Person) REQUIRE p.surname IS UNIQUE")
    })
  SparkConnectorScalaSuiteIT.session()
    .writeTransaction(new TransactionWork[Result] {
      override def execute(transaction: Transaction): Result =
        transaction.run("CREATE (p:Person{name: 'Andrea', surname: 'Santurbano'})")
    })
  val ds = Seq(SimplePerson("Andrea", "Santurbano")).toDS()
  try {
    val thrown = the[SparkException] thrownBy {
      ds.write
        .format(classOf[DataSource].getName)
        .mode(SaveMode.Append)
        .option("url", SparkConnectorScalaSuiteIT.server.getBoltUrl)
        .option("labels", "Person")
        .save() // we need the action to be able to trigger the exception because of the changes in Spark 3
    }
    thrown.getMessage should include("org.neo4j.driver.exceptions.ClientException")
    val rootCause = ExceptionUtils.getRootCause(thrown)
    // root cause is not always returned as a ClientException so we pass it through pattern matching to remove flakiness
    rootCause match {
      case c: ClientException => c.code() should be("Neo.ClientError.Schema.ConstraintValidationFailed")
      case _ =>
    }
  } finally {
    SparkConnectorScalaSuiteIT.session()
      .writeTransaction(new TransactionWork[Result] {
        override def execute(transaction: Transaction): Result =
          transaction.run("DROP CONSTRAINT person_surname")
      })
  }
}
// SaveMode.Overwrite with node.keys=surname must MERGE: the pre-existing node keyed by
// surname gets its name updated instead of a second node being created.
@Test
def `should update the node that already exists`(): Unit = {
  SparkConnectorScalaSuiteIT.session()
    .writeTransaction(new TransactionWork[Result] {
      override def execute(transaction: Transaction): Result =
        transaction.run("CREATE CONSTRAINT person_surname FOR (p:Person) REQUIRE p.surname IS UNIQUE")
    })
  SparkConnectorScalaSuiteIT.session()
    .writeTransaction(new TransactionWork[Result] {
      override def execute(transaction: Transaction): Result =
        transaction.run("CREATE (p:Person{name: 'Federico', surname: 'Santurbano'})")
    })
  val ds = Seq(SimplePerson("Andrea", "Santurbano")).toDS()
  ds.write
    .format(classOf[DataSource].getName)
    .mode(SaveMode.Overwrite)
    .option("url", SparkConnectorScalaSuiteIT.server.getBoltUrl)
    .option("labels", "Person")
    .option("node.keys", "surname")
    .save()
  val nodeList = SparkConnectorScalaSuiteIT.session()
    .run(
      """MATCH (n:Person{surname: 'Santurbano'})
        |RETURN n
        |""".stripMargin
    )
    .list()
    .asScala
  assertEquals(1, nodeList.size)
  assertEquals("Andrea", nodeList.head.get("n").asNode().get("name").asString())
  // cleanup so the uniqueness constraint doesn't leak into other tests
  SparkConnectorScalaSuiteIT.session()
    .writeTransaction(new TransactionWork[Result] {
      override def execute(transaction: Transaction): Result =
        transaction.run("DROP CONSTRAINT person_surname")
    })
}
// Null column values must not become node properties at all (no `surname` key on the node).
@Test
def `should skip null properties`(): Unit = {
  val ds = Seq(SimplePerson("Andrea", null)).toDS()
  ds.write
    .format(classOf[DataSource].getName)
    .mode(SaveMode.Append)
    .option("url", SparkConnectorScalaSuiteIT.server.getBoltUrl)
    .option("labels", "Person")
    .save()
  val nodeList = SparkConnectorScalaSuiteIT.session()
    .run(
      """MATCH (n:Person{name: 'Andrea'})
        |RETURN n
        |""".stripMargin
    )
    .list()
    .asScala
  assertEquals(1, nodeList.size)
  val node = nodeList.head.get("n").asNode()
  assertFalse("surname should not exist", node.asMap().containsKey("surname"))
}

// Overwrite mode requires node.keys to know what to MERGE on; missing keys must fail.
@Test
def `should throw an error because SaveMode.Overwrite need node.keys`(): Unit = {
  val ds = Seq(SimplePerson("Andrea", "Santurbano")).toDS()
  try {
    ds.write
      .format(classOf[DataSource].getName)
      .mode(SaveMode.Overwrite)
      .option("url", SparkConnectorScalaSuiteIT.server.getBoltUrl)
      .option("labels", "Person")
      .save() // we need the action to be able to trigger the exception because of the changes in Spark 3
  } catch {
    case illegalArgumentException: IllegalArgumentException => {
      assertTrue(illegalArgumentException.getMessage.equals(
        s"${Neo4jOptions.NODE_KEYS} is required when Save Mode is Overwrite"
      ))
    }
    case e: Throwable =>
      fail(s"should be thrown a ${classOf[IllegalArgumentException].getName} but is ${e.getClass.getSimpleName}")
  }
}

// 100 rows over 10 partitions with batch.size=11: all rows must land, with exactly the
// expected property keys on each node.
@Test
def `should write within partitions`(): Unit = {
  val ds = (1 to 100).map(i => Person("Andrea " + i, "Santurbano " + i, 36, null)).toDS()
    .repartition(10)
  ds.write
    .format(classOf[DataSource].getName)
    .mode(SaveMode.Append)
    .option("url", SparkConnectorScalaSuiteIT.server.getBoltUrl)
    .option("labels", ":Person:Customer")
    .option("batch.size", "11")
    .save()
  val count = SparkConnectorScalaSuiteIT.session().run(
    """
      |MATCH (p:Person:Customer)
      |WHERE p.name STARTS WITH 'Andrea'
      |AND p.surname STARTS WITH 'Santurbano'
      |RETURN count(p) AS count
      |""".stripMargin
  ).single().get("count").asInt()
  assertEquals(100, count)
  val keys = SparkConnectorScalaSuiteIT.session().run(
    """
      |MATCH (p:Person:Customer)
      |WHERE p.name STARTS WITH 'Andrea'
      |AND p.surname STARTS WITH 'Santurbano'
      |RETURN DISTINCT keys(p) AS keys
      |""".stripMargin
  ).single().get("keys").asList()
  // livesIn is null in the fixture, so it must not appear as a key
  assertEquals(Set("name", "surname", "age"), keys.asScala.toSet)
}
// Disabled: writing with a read-only query should be rejected, but the connector cannot
// currently distinguish read vs write context (see @Ignore reason).
@Test
@Ignore("This won't work right now because we can't know if we are in a Write or Read context")
def `should throw an exception for a read only query`(): Unit = {
  val ds = (1 to 100).map(i => Person("Andrea " + i, "Santurbano " + i, 36, null)).toDS()
  try {
    ds.write
      .mode(SaveMode.Overwrite)
      .format(classOf[DataSource].getName)
      .option("url", SparkConnectorScalaSuiteIT.server.getBoltUrl)
      .option("query", "MATCH (r:Read) RETURN r")
      .option("batch.size", "11")
      .save() // we need the action to be able to trigger the exception because of the changes in Spark 3
  } catch {
    case illegalArgumentException: IllegalArgumentException =>
      assertTrue(illegalArgumentException.getMessage.equals("Please provide a valid WRITE query"))
    case t: Throwable =>
      fail(
        s"should be thrown a ${classOf[IllegalArgumentException].getName}, but it's ${t.getClass.getSimpleName}: ${t.getMessage}"
      )
  }
}

// Custom write query: each row is exposed as `event`, and the query derives new
// properties (fullName concatenation, age - 10) which are then verified by count.
@Test
def `should insert data with a custom query`(): Unit = {
  val ds = (1 to 100).map(i => Person("Andrea " + i, "Santurbano " + i, 36, null)).toDS()
  ds.write
    .format(classOf[DataSource].getName)
    .mode(SaveMode.Append)
    .option("url", SparkConnectorScalaSuiteIT.server.getBoltUrl)
    .option("query", "CREATE (n:MyNode{fullName: event.name + event.surname, age: event.age - 10})")
    .option("batch.size", "11")
    .save()
  val count = SparkConnectorScalaSuiteIT.session().run(
    """
      |MATCH (p:MyNode)
      |WHERE p.fullName CONTAINS 'Andrea'
      |AND p.fullName CONTAINS 'Santurbano'
      |AND p.age = 26
      |RETURN count(p) AS count
      |""".stripMargin
  ).single().get("count").asInt()
  assertEquals(ds.count(), count)
}
// Column names containing backticks, spaces and unicode must survive a relationship
// write (with a rename via "sourceColumn:targetProperty") and a subsequent read.
@Test
def `should handle unusual column names`(): Unit = {
  SparkConnectorScalaSuiteIT.session()
    .writeTransaction(new TransactionWork[Result] {
      override def execute(transaction: Transaction): Result =
        transaction.run("CREATE CONSTRAINT instrument_name FOR (i:Instrument) REQUIRE i.name IS UNIQUE")
    })
  val musicDf = Seq(
    (12, "John Bonham", "Drums", "f``````oo"),
    (19, "John Mayer", "Guitar", "bar"),
    (32, "John Scofield", "Guitar", "ba` z"),
    (15, "John Butler", "Guitar", "qu ux")
  ).toDF("experience", "name", "instrument", "fi``(╯°□°)╯︵ ┻━┻eld")
  musicDf.write
    .option("url", SparkConnectorScalaSuiteIT.server.getBoltUrl)
    .format(classOf[DataSource].getName)
    .mode(SaveMode.Overwrite)
    .option("relationship", "PLAYS")
    .option("relationship.save.strategy", "keys")
    .option("relationship.source.labels", ":Musician")
    .option("relationship.source.save.mode", "Overwrite")
    .option("relationship.source.node.keys", "name")
    // the weird column is mapped onto a sane node property name
    .option("relationship.source.node.properties", "fi``(╯°□°)╯︵ ┻━┻eld:field")
    .option("relationship.target.labels", ":Instrument")
    .option("relationship.target.node.keys", "instrument:name")
    .option("relationship.target.save.mode", "Overwrite")
    .save()
  SparkConnectorScalaSuiteIT.session()
    .writeTransaction(new TransactionWork[Result] {
      override def execute(transaction: Transaction): Result =
        transaction.run("DROP CONSTRAINT instrument_name")
    })
  val musicDfCheck = ss.read.format(classOf[DataSource].getName)
    .option("url", SparkConnectorScalaSuiteIT.server.getBoltUrl)
    .option("relationship", "PLAYS")
    .option("relationship.nodes.map", "false")
    .option("relationship.source.labels", ":Musician")
    .option("relationship.target.labels", ":Instrument")
    .load()
  val size = musicDfCheck.count
  assertEquals(4, size)
  // NOTE(review): the following positional getString indices depend on the flattened
  // relationship schema produced by relationship.nodes.map=false
  val res = musicDfCheck.orderBy("`source.name`").collectAsList()
  assertEquals("John Bonham", res.get(0).getString(4))
  assertEquals("f``````oo", res.get(0).getString(5))
  assertEquals("Drums", res.get(0).getString(8))
  assertEquals("John Butler", res.get(1).getString(4))
  assertEquals("qu ux", res.get(1).getString(5))
  assertEquals("Guitar", res.get(1).getString(8))
  assertEquals("John Mayer", res.get(2).getString(4))
  assertEquals("bar", res.get(2).getString(5))
  assertEquals("Guitar", res.get(2).getString(8))
  assertEquals("John Scofield", res.get(3).getString(4))
  assertEquals("ba` z", res.get(3).getString(5))
  assertEquals("Guitar", res.get(3).getString(8))
}
// NATIVE relationship strategy needs rel./source./target.-prefixed columns; with a plain
// schema it must fail with an explanatory message, then the SparkException is rethrown to
// satisfy the `expected` attribute.
@Test(expected = classOf[SparkException])
def `should give error if native mode doesn't find a valid schema`(): Unit = {
  val musicDf = Seq(
    (12, "John Bonham", "Drums"),
    (19, "John Mayer", "Guitar"),
    (32, "John Scofield", "Guitar"),
    (15, "John Butler", "Guitar")
  ).toDF("experience", "name", "instrument")
  try {
    musicDf.write
      .format(classOf[DataSource].getName)
      .mode(SaveMode.Append)
      .option("url", SparkConnectorScalaSuiteIT.server.getBoltUrl)
      .option("relationship", "PLAYS")
      .option("relationship.save.strategy", "NATIVE")
      .option("relationship.source.labels", ":Person")
      .option("relationship.source.save.mode", "Overwrite")
      .option("relationship.target.labels", ":Instrument")
      .option("relationship.target.save.mode", "Overwrite")
      .save() // we need the action to be able to trigger the exception because of the changes in Spark 3
  } catch {
    case sparkException: SparkException => {
      val clientException = ExceptionUtils.getRootCause(sparkException)
      assertTrue(clientException.getMessage.equals(
        "NATIVE write strategy requires a schema like: rel.[props], source.[props], target.[props]. " +
          "All of these columns are empty in the current schema."
      ))
      throw sparkException
    }
    case _: Throwable => fail(s"should be thrown a ${classOf[SparkException].getName}")
  }
}
// KEYS strategy: source/target nodes are merged on the configured keys and the PLAYS
// relationship is created; read back flattened and verified positionally.
@Test
def `should write relations with KEYS mode`(): Unit = {
  val musicDf = Seq(
    (12, "John Bonham", "Drums"),
    (19, "John Mayer", "Guitar"),
    (32, "John Scofield", "Guitar"),
    (15, "John Butler", "Guitar")
  ).toDF("experience", "name", "instrument")
  musicDf.repartition(1).write
    .format(classOf[DataSource].getName)
    .mode(SaveMode.Overwrite)
    .option("url", SparkConnectorScalaSuiteIT.server.getBoltUrl)
    .option("relationship", "PLAYS")
    .option("relationship.source.save.mode", "Overwrite")
    .option("relationship.target.save.mode", "Overwrite")
    .option("relationship.save.strategy", "keys")
    .option("relationship.source.labels", ":Musician")
    .option("relationship.source.node.keys", "name:name")
    .option("relationship.target.labels", ":Instrument")
    .option("relationship.target.node.keys", "instrument:name")
    .save()
  val df2 = ss.read.format(classOf[DataSource].getName)
    .option("batch.size", 100)
    .option("url", SparkConnectorScalaSuiteIT.server.getBoltUrl)
    .option("relationship.nodes.map", "false")
    .option("relationship", "PLAYS")
    .option("relationship.source.labels", ":Musician")
    .option("relationship.target.labels", ":Instrument")
    .load()
  assertEquals(4, df2.count())
  val res = df2.orderBy("`source.name`").collectAsList()
  assertEquals("John Bonham", res.get(0).getString(4))
  assertEquals("Drums", res.get(0).getString(7))
  assertEquals("John Butler", res.get(1).getString(4))
  assertEquals("Guitar", res.get(1).getString(7))
  assertEquals("John Mayer", res.get(2).getString(4))
  assertEquals("Guitar", res.get(2).getString(7))
  assertEquals("John Scofield", res.get(3).getString(4))
  assertEquals("Guitar", res.get(3).getString(7))
}

// relationship.source.save.mode=ErrorIfExists is unsupported on Spark 3 and must be
// rejected during option validation.
@Test
def `should fail validating options if ErrorIfExists is used`(): Unit = {
  val musicDf = Seq(
    (12, "John Bonham", "Drums"),
    (19, "John Mayer", "Guitar"),
    (32, "John Scofield", "Guitar"),
    (15, "John Butler", "Guitar")
  ).toDF("experience", "name", "instrument")
  try {
    musicDf.repartition(1).write
      .format(classOf[DataSource].getName)
      .mode(SaveMode.Overwrite)
      .option("url", SparkConnectorScalaSuiteIT.server.getBoltUrl)
      .option("relationship", "PLAYS")
      .option("relationship.source.save.mode", "ErrorIfExists")
      .option("relationship.target.save.mode", "Overwrite")
      .option("relationship.save.strategy", "keys")
      .option("relationship.source.labels", ":Musician")
      .option("relationship.source.node.keys", "name:name")
      .option("relationship.target.labels", ":Instrument")
      .option("relationship.target.node.keys", "instrument:name")
      .save()
  } catch {
    case e: IllegalArgumentException =>
      assertEquals("Save mode 'ErrorIfExists' is not supported on Spark 3.0, use 'Append' instead.", e.getMessage)
    case _: Throwable => fail(s"should be thrown a ${classOf[IllegalArgumentException].getName}")
  }
}
// Disabled: attempt to reproduce a deadlock by re-writing 200 identical rows across 10
// partitions with transaction.retries=0 (see @Ignore reason).
@Test
@Ignore("trying to recreate the deadlock issue")
def `should give better errors if transaction fails`(): Unit = {
  val df = List.fill(200)(("John Bonham", "Drums")).toDF("name", "instrument")
  // first write seeds the data single-threaded
  df.write
    .format(classOf[DataSource].getName)
    .mode(SaveMode.Overwrite)
    .option("url", SparkConnectorScalaSuiteIT.server.getBoltUrl)
    .option("relationship", "PLAYS")
    .option("relationship.source.save.mode", "Overwrite")
    .option("relationship.target.save.mode", "Overwrite")
    .option("relationship.save.strategy", "keys")
    .option("relationship.source.labels", ":Musician")
    .option("relationship.source.node.keys", "name:name")
    .option("relationship.target.labels", ":Instrument")
    .option("relationship.target.node.keys", "instrument:name")
    .save()
  // second write forces contention: 10 partitions all merging the same two nodes, no retries
  df.write
    .format(classOf[DataSource].getName)
    .mode(SaveMode.Overwrite)
    .option("url", SparkConnectorScalaSuiteIT.server.getBoltUrl)
    .option("transaction.retries", 0)
    .option("partitions", "10")
    .option("relationship", "PLAYS")
    .option("relationship.source.save.mode", "Overwrite")
    .option("relationship.target.save.mode", "Overwrite")
    .option("relationship.save.strategy", "keys")
    .option("relationship.source.labels", ":Musician")
    .option("relationship.source.node.keys", "name:name")
    .option("relationship.target.labels", ":Instrument")
    .option("relationship.target.node.keys", "instrument:name")
    .save()
}
// Shared fixture: writes a 4-row musician dataset as PLAYS relationships using the KEYS
// strategy (options tweakable via optionModifier), then reads it back flattened
// (relationship.nodes.map=false) and returns the resulting DataFrame.
def writeKeyModeRelationshipWriteDataSet(
  optionModifier: Map[String, String] => Map[String, String] = { m => m }
): DataFrame = {
  val musicDf = Seq(
    (12, "John Bonham", "Drums", 2, true),
    (19, "John Mayer", "Guitar", 1, false),
    (32, "John Scofield", "Guitar", 3, true),
    (15, "John Butler", "Guitar", 4, false)
  ).toDF("experience", "name", "instrument", "rating", "hasDiploma")
  val options = Map(
    "url" -> SparkConnectorScalaSuiteIT.server.getBoltUrl,
    "relationship" -> "PLAYS",
    "relationship.source.save.mode" -> "Overwrite",
    "relationship.target.save.mode" -> "Overwrite",
    "relationship.save.strategy" -> "keys",
    "relationship.source.labels" -> ":Musician",
    "relationship.source.node.keys" -> "name",
    "relationship.target.labels" -> ":Instrument",
    "relationship.target.node.keys" -> "instrument:name"
  )
  val modifiedOptions = optionModifier(options)
  musicDf.repartition(1).write
    .format(classOf[DataSource].getName)
    .mode(SaveMode.Overwrite)
    .options(modifiedOptions)
    .save()
  ss.read.format(classOf[DataSource].getName)
    .option("url", SparkConnectorScalaSuiteIT.server.getBoltUrl)
    .option("relationship.nodes.map", "false")
    .option("relationship", "PLAYS")
    .option("relationship.source.labels", ":Musician")
    .option("relationship.target.labels", ":Instrument")
    .load()
}

// With relationship.properties explicitly listed, only those columns (with optional
// renames like rating:avgRating) become relationship properties; all others — including
// the node-key column "name" — must be absent, which is asserted via fieldIndex throwing.
@Test
def `should write relations with KEYS mode with explicitly listed properties`(): Unit = {
  val resultDf = writeKeyModeRelationshipWriteDataSet({ options =>
    options + ("relationship.properties" -> "experience, rating:avgRating, instrument")
  })
  resultDf.show(false)
  assertEquals(4, resultDf.count())
  val res = resultDf.orderBy("`source.name`").collectAsList()
  assertEquals("John Bonham", getByName[String](res.get(0), "source.name"))
  assertEquals("Drums", getByName[String](res.get(0), "target.name"))
  assertEquals("Drums", getByName[String](res.get(0), "rel.instrument"))
  assertEquals(12, getByName[Long](res.get(0), "rel.experience"))
  assertEquals(2, getByName[Long](res.get(0), "rel.avgRating"))
  assertThrows[IllegalArgumentException](
    "relationship should not have hasDiploma field",
    res.get(0).fieldIndex("rel.hasDiploma")
  )
  assertThrows[IllegalArgumentException](
    "relationship should not have rating field",
    res.get(0).fieldIndex("rel.rating")
  )
  assertThrows[IllegalArgumentException]("relationship should not have name field", res.get(0).fieldIndex("rel.name"))
  assertEquals("John Butler", getByName[String](res.get(1), "source.name"))
  assertEquals("Guitar", getByName[String](res.get(1), "target.name"))
  assertEquals("Guitar", getByName[String](res.get(1), "rel.instrument"))
  assertEquals(15, getByName[Long](res.get(1), "rel.experience"))
  assertEquals(4, getByName[Long](res.get(1), "rel.avgRating"))
  assertThrows[IllegalArgumentException](
    "relationship should not have hasDiploma field",
    res.get(1).fieldIndex("rel.hasDiploma")
  )
  assertThrows[IllegalArgumentException](
    "relationship should not have rating field",
    res.get(1).fieldIndex("rel.rating")
  )
  assertThrows[IllegalArgumentException]("relationship should not have name field", res.get(1).fieldIndex("rel.name"))
  assertEquals("John Mayer", getByName[String](res.get(2), "source.name"))
  assertEquals("Guitar", getByName[String](res.get(2), "target.name"))
  assertEquals("Guitar", getByName[String](res.get(2), "rel.instrument"))
  assertEquals(19, getByName[Long](res.get(2), "rel.experience"))
  assertEquals(1, getByName[Long](res.get(2), "rel.avgRating"))
  assertThrows[IllegalArgumentException](
    "relationship should not have hasDiploma field",
    res.get(2).fieldIndex("rel.hasDiploma")
  )
  assertThrows[IllegalArgumentException](
    "relationship should not have rating field",
    res.get(2).fieldIndex("rel.rating")
  )
  assertThrows[IllegalArgumentException]("relationship should not have name field", res.get(2).fieldIndex("rel.name"))
  assertEquals("John Scofield", getByName[String](res.get(3), "source.name"))
  assertEquals("Guitar", getByName[String](res.get(3), "target.name"))
  assertEquals("Guitar", getByName[String](res.get(3), "rel.instrument"))
  assertEquals(32, getByName[Long](res.get(3), "rel.experience"))
  assertEquals(3, getByName[Long](res.get(3), "rel.avgRating"))
  assertThrows[IllegalArgumentException](
    "relationship should not have hasDiploma field",
    res.get(3).fieldIndex("rel.hasDiploma")
  )
  assertThrows[IllegalArgumentException](
    "relationship should not have rating field",
    res.get(3).fieldIndex("rel.rating")
  )
  assertThrows[IllegalArgumentException]("relationship should not have name field", res.get(3).fieldIndex("rel.name"))
}
assertThrows[IllegalArgumentException]( "relationship should not have experience field", res.get(1).fieldIndex("rel.experience") ) assertThrows[IllegalArgumentException]( "relationship should not have hasDiploma field", res.get(1).fieldIndex("rel.hasDiploma") ) assertThrows[IllegalArgumentException]( "relationship should not have rating field", res.get(1).fieldIndex("rel.rating") ) assertThrows[IllegalArgumentException]("relationship should not have name field", res.get(1).fieldIndex("rel.name")) assertThrows[IllegalArgumentException]( "relationship should not have instrument field", res.get(1).fieldIndex("rel.instrument") ) assertEquals("John Mayer", getByName[String](res.get(2), "source.name")) assertEquals("Guitar", getByName[String](res.get(2), "target.name")) assertThrows[IllegalArgumentException]( "relationship should not have experience field", res.get(2).fieldIndex("rel.experience") ) assertThrows[IllegalArgumentException]( "relationship should not have hasDiploma field", res.get(2).fieldIndex("rel.hasDiploma") ) assertThrows[IllegalArgumentException]( "relationship should not have rating field", res.get(2).fieldIndex("rel.rating") ) assertThrows[IllegalArgumentException]("relationship should not have name field", res.get(2).fieldIndex("rel.name")) assertThrows[IllegalArgumentException]( "relationship should not have instrument field", res.get(2).fieldIndex("rel.instrument") ) assertEquals("John Scofield", getByName[String](res.get(3), "source.name")) assertEquals("Guitar", getByName[String](res.get(3), "target.name")) assertThrows[IllegalArgumentException]( "relationship should not have experience field", res.get(3).fieldIndex("rel.experience") ) assertThrows[IllegalArgumentException]( "relationship should not have hasDiploma field", res.get(3).fieldIndex("rel.hasDiploma") ) assertThrows[IllegalArgumentException]( "relationship should not have rating field", res.get(3).fieldIndex("rel.rating") ) assertThrows[IllegalArgumentException]("relationship should 
not have name field", res.get(3).fieldIndex("rel.name")) assertThrows[IllegalArgumentException]( "relationship should not have instrument field", res.get(3).fieldIndex("rel.instrument") ) } @Test def `should write relations with KEYS mode with default properties`(): Unit = { val resultDf = writeKeyModeRelationshipWriteDataSet() resultDf.show(false) assertEquals(4, resultDf.count()) val res = resultDf.orderBy("`source.name`").collectAsList() assertEquals("John Bonham", getByName[String](res.get(0), "source.name")) assertEquals("Drums", getByName[String](res.get(0), "target.name")) assertEquals(12, getByName[Long](res.get(0), "rel.experience")) assertEquals(true, getByName[Boolean](res.get(0), "rel.hasDiploma")) assertEquals(2, getByName[Long](res.get(0), "rel.rating")) assertThrows[IllegalArgumentException]("relationship should not have name field", res.get(0).fieldIndex("rel.name")) assertThrows[IllegalArgumentException]( "relationship should not have instrument field", res.get(0).fieldIndex("rel.instrument") ) assertEquals("John Butler", getByName[String](res.get(1), "source.name")) assertEquals("Guitar", getByName[String](res.get(1), "target.name")) assertEquals(15, getByName[Long](res.get(1), "rel.experience")) assertEquals(false, getByName[Boolean](res.get(1), "rel.hasDiploma")) assertEquals(4, getByName[Long](res.get(1), "rel.rating")) assertThrows[IllegalArgumentException]("relationship should not have name field", res.get(1).fieldIndex("rel.name")) assertThrows[IllegalArgumentException]( "relationship should not have instrument field", res.get(1).fieldIndex("rel.instrument") ) assertEquals("John Mayer", getByName[String](res.get(2), "source.name")) assertEquals("Guitar", getByName[String](res.get(2), "target.name")) assertEquals(19, getByName[Long](res.get(2), "rel.experience")) assertEquals(false, getByName[Boolean](res.get(2), "rel.hasDiploma")) assertEquals(1, getByName[Long](res.get(2), "rel.rating")) assertThrows[IllegalArgumentException]("relationship 
should not have name field", res.get(2).fieldIndex("rel.name")) assertThrows[IllegalArgumentException]( "relationship should not have instrument field", res.get(2).fieldIndex("rel.instrument") ) assertEquals("John Scofield", getByName[String](res.get(3), "source.name")) assertEquals("Guitar", getByName[String](res.get(3), "target.name")) assertEquals(32, getByName[Long](res.get(3), "rel.experience")) assertEquals(true, getByName[Boolean](res.get(3), "rel.hasDiploma")) assertEquals(3, getByName[Long](res.get(3), "rel.rating")) assertThrows[IllegalArgumentException]("relationship should not have name field", res.get(3).fieldIndex("rel.name")) assertThrows[IllegalArgumentException]( "relationship should not have instrument field", res.get(3).fieldIndex("rel.instrument") ) } @Test def `should read and write relations with node overwrite mode`(): Unit = { val fixtureQuery: String = s"""CREATE (m:Musician {id: 1, name: "John Bonham"}) |CREATE (i:Instrument {name: "Drums"}) |CREATE (m)-[:PLAYS {experience: 10}]->(i) |RETURN * """.stripMargin SparkConnectorScalaSuiteIT.driver.session() .writeTransaction( new TransactionWork[ResultSummary] { override def execute(tx: Transaction): ResultSummary = tx.run(fixtureQuery).consume() } ) val musicDf = Seq( (1, 12, "John Henry Bonham", "Drums"), (2, 19, "John Mayer", "Guitar"), (3, 32, "John Scofield", "Guitar"), (4, 15, "John Butler", "Guitar") ).toDF("id", "experience", "name", "instrument") musicDf.write .format(classOf[DataSource].getName) .mode(SaveMode.Overwrite) .option("url", SparkConnectorScalaSuiteIT.server.getBoltUrl) .option("relationship.nodes.map", "false") .option("relationship.source.save.mode", "Overwrite") .option("relationship.target.save.mode", "Overwrite") .option("relationship", "PLAYS") .option("relationship.properties", "experience") .option("relationship.save.strategy", "keys") .option("relationship.source.labels", ":Musician") .option("relationship.source.node.keys", "id") 
.option("relationship.source.node.properties", "name") .option("relationship.target.labels", ":Instrument") .option("relationship.target.node.keys", "instrument:name") .save() val df2 = ss.read.format(classOf[DataSource].getName) .option("url", SparkConnectorScalaSuiteIT.server.getBoltUrl) .option("relationship.nodes.map", "false") .option("relationship", "PLAYS") .option("relationship.source.labels", ":Musician") .option("relationship.target.labels", ":Instrument") .load() val result = df2.where("`source.id` = 1") .collectAsList().get(0) assertEquals(12, result.getLong(9)) assertEquals("John Henry Bonham", result.getString(4)) } @Test def `should insert index while insert nodes`(): Unit = { val ds = (1 to 10) .map(i => i.toString) .toDF("surname") ds.write .format(classOf[DataSource].getName) .mode(SaveMode.Overwrite) .option("url", SparkConnectorScalaSuiteIT.server.getBoltUrl) .option("labels", ":Person:Customer") .option("node.keys", "surname") .option("schema.optimization.type", "INDEX") .save() val records = SparkConnectorScalaSuiteIT.session().run( """MATCH (p:Person:Customer) |RETURN p.surname AS surname |""".stripMargin ).list().asScala .map(r => r.asMap().asScala) .toSet val expected = ds.collect().map(row => Map("surname" -> row.getAs[String]("surname"))) .toSet assertEquals(expected, records) val indexCount = SparkConnectorScalaSuiteIT.session().run( getIndexQueryCount ) .single() .get("count") .asLong() assertEquals(1, indexCount) SparkConnectorScalaSuiteIT.session().run("DROP INDEX spark_INDEX_Person_surname") } private def getIndexQueryCount: String = { val (uniqueKey, uniqueCondition) = if (TestUtil.neo4jVersion(SparkConnectorScalaSuiteIT.session()) >= Versions.NEO4J_5) { ("owningConstraint", "owningConstraint IS NULL") } else { ("uniqueness", "uniqueness = 'NONUNIQUE'") } s"""SHOW INDEXES YIELD labelsOrTypes, properties, $uniqueKey |WHERE labelsOrTypes = ['Person'] AND properties = ['surname'] AND $uniqueCondition |RETURN count(*) AS count 
|""".stripMargin } private def getConstraintQueryCount: String = { val (uniqueKey, uniqueCondition) = if (TestUtil.neo4jVersion(SparkConnectorScalaSuiteIT.session()) >= Versions.NEO4J_5) { ("owningConstraint", "owningConstraint IS NOT NULL") } else { ("uniqueness", "uniqueness = 'UNIQUE'") } s"""SHOW INDEXES YIELD labelsOrTypes, properties, $uniqueKey |WHERE labelsOrTypes = ['Person'] AND properties = ['surname'] AND $uniqueCondition |RETURN count(*) AS count |""".stripMargin } @Test def `should create constraint when insert nodes`(): Unit = { val ds = (1 to 10) .map(i => i.toString) .toDF("surname") ds.write .format(classOf[DataSource].getName) .mode(SaveMode.Overwrite) .option("url", SparkConnectorScalaSuiteIT.server.getBoltUrl) .option("labels", ":Person:Customer") .option("node.keys", "surname") .option("schema.optimization.type", "NODE_CONSTRAINTS") .save() val records = SparkConnectorScalaSuiteIT.session().run( """MATCH (p:Person:Customer) |RETURN p.surname AS surname |""".stripMargin ).list().asScala .map(r => r.asMap().asScala) .toSet val expected = ds.collect().map(row => Map("surname" -> row.getAs[String]("surname"))) .toSet assertEquals(expected, records) val constraintCount = SparkConnectorScalaSuiteIT.session().run( getConstraintQueryCount ) .single() .get("count") .asLong() assertEquals(1, constraintCount) SparkConnectorScalaSuiteIT.session().run("DROP CONSTRAINT spark_NODE_CONSTRAINTS_Person_surname") } @Test def `should not create constraint when insert nodes because they already exist`(): Unit = { SparkConnectorScalaSuiteIT.session().run( "CREATE CONSTRAINT person_surname FOR (p:Person) REQUIRE (p.surname) IS UNIQUE" ) val ds = (1 to 10) .map(i => i.toString) .toDF("surname") ds.write .format(classOf[DataSource].getName) .mode(SaveMode.Overwrite) .option("url", SparkConnectorScalaSuiteIT.server.getBoltUrl) .option("labels", ":Person:Customer") .option("node.keys", "surname") .option("schema.optimization.type", "NODE_CONSTRAINTS") .save() val 
records = SparkConnectorScalaSuiteIT.session().run( """MATCH (p:Person:Customer) |RETURN p.surname AS surname |""".stripMargin ).list().asScala .map(r => r.asMap().asScala) .toSet val expected = ds.collect().map(row => Map("surname" -> row.getAs[String]("surname"))) .toSet assertEquals(expected, records) val constraintCount = SparkConnectorScalaSuiteIT.session().run( getConstraintQueryCount ) .single() .get("count") .asLong() assertEquals(1, constraintCount) SparkConnectorScalaSuiteIT.session().run("DROP CONSTRAINT person_surname") } @Test def `should insert indexes while insert with query`(): Unit = { val ds = (1 to 10) .map(i => i.toString) .toDF("surname") ds.write .format(classOf[DataSource].getName) .mode(SaveMode.Overwrite) .option("url", SparkConnectorScalaSuiteIT.server.getBoltUrl) .option("labels", ":Person:Customer") .option("node.keys", "surname") .option("schema.optimization.type", "INDEX") .save() ds.write .format(classOf[DataSource].getName) .mode(SaveMode.Append) .option("url", SparkConnectorScalaSuiteIT.server.getBoltUrl) .option("query", "CREATE (n:MyNode{fullName: event.name + event.surname, age: event.age - 10})") .option("batch.size", "11") .save() val records = SparkConnectorScalaSuiteIT.session().run( """MATCH (p:Person:Customer) |RETURN p.surname AS surname |""".stripMargin ).list().asScala .map(r => r.asMap().asScala) .toSet val expected = ds.collect().map(row => Map("surname" -> row.getAs[String]("surname"))) .toSet assertEquals(expected, records) val indexCount = SparkConnectorScalaSuiteIT.session() .run(getIndexQueryCount) .single() .get("count") .asLong() assertEquals(1, indexCount) SparkConnectorScalaSuiteIT.session().run("DROP INDEX spark_INDEX_Person_surname") } @Test def `should manage script passing the data to the executors`(): Unit = { val ds = Seq(SimplePerson("Andrea", "Santurbano"), SimplePerson("Davide", "Fantuzzi")).toDS() .repartition(2) ds.write .format(classOf[DataSource].getName) .mode(SaveMode.Append) .option("url", 
SparkConnectorScalaSuiteIT.server.getBoltUrl) .option( "query", "CREATE (n:Person{fullName: event.name + ' ' + event.surname, age: scriptResult[0].age[event.name]})" ) .option( "script", """CREATE INDEX person_surname FOR (p:Person) ON (p.surname); |CREATE CONSTRAINT product_name_sku FOR (p:Product) | REQUIRE (p.name, p.sku) | IS NODE KEY; |RETURN {Andrea: 36, Davide: 32} AS age; |""".stripMargin ) .save() val records = SparkConnectorScalaSuiteIT.session().run( """MATCH (p:Person) |WHERE (p.fullName = 'Andrea Santurbano' AND p.age = 36) |OR (p.fullName = 'Davide Fantuzzi' AND p.age = 32) |RETURN count(p) AS count |""".stripMargin ) .single() .get("count") .asLong() val expected = ds.count assertEquals(expected, records) val uniqueFieldName = if (TestUtil.neo4jVersion(SparkConnectorScalaSuiteIT.session()) >= Versions.NEO4J_5) "owningConstraint" else "uniqueness" val (indexCondition, uniqueCondition) = if (TestUtil.neo4jVersion(SparkConnectorScalaSuiteIT.session()) >= Versions.NEO4J_5) { (s"$uniqueFieldName IS NULL", s"$uniqueFieldName IS NOT NULL") } else { (s"$uniqueFieldName = 'NONUNIQUE'", s"$uniqueFieldName = 'UNIQUE'") } val query = s"""SHOW INDEXES YIELD labelsOrTypes, properties, $uniqueFieldName |WHERE (labelsOrTypes = ['Person'] AND properties = ['surname'] AND $indexCondition) |OR (labelsOrTypes = ['Product'] AND properties = ['name', 'sku'] AND $uniqueCondition) |RETURN count(*) AS count |""".stripMargin val constraintCount = SparkConnectorScalaSuiteIT.session() .run(query) .single() .get("count") .asLong() assertEquals(2, constraintCount) SparkConnectorScalaSuiteIT.session().run("DROP INDEX person_surname") SparkConnectorScalaSuiteIT.session().run("DROP CONSTRAINT product_name_sku") } @Test def `should work create source node and match target node`() { val data = Seq( (12, "John Bonham", "Drums"), (19, "John Mayer", "Guitar"), (32, "John Scofield", "Guitar"), (15, "John Butler", "Guitar") ) SparkConnectorScalaSuiteIT.session().run("CREATE " + data 
.map(_._3) .toSet[String] .map(instrument => s"(:Instrument{name: '$instrument'})") .mkString(", ")) val musicDf = data.toDF("experience", "name", "instrument") musicDf.write .mode(SaveMode.Overwrite) .format(classOf[DataSource].getName) .option("url", SparkConnectorScalaSuiteIT.server.getBoltUrl) .option("relationship", "PLAYS") .option("relationship.save.strategy", "keys") .option("relationship.source.save.mode", "Overwrite") .option("relationship.source.labels", ":Musician") .option("relationship.source.node.keys", "name") .option("relationship.target.save.mode", "match") .option("relationship.target.labels", ":Instrument") .option("relationship.target.node.keys", "instrument:name") .save val count = SparkConnectorScalaSuiteIT.session().run( """MATCH p = (:Musician)-[:PLAYS]->(:Instrument) |RETURN count(p) AS count""".stripMargin ) .single() .get("count") .asLong() assertEquals(data.size, count) } @Test def `should work match source node and merge target node`() { SparkConnectorScalaSuiteIT.session().run( "CREATE CONSTRAINT musician_name FOR (m:Musician) REQUIRE (m.name) IS UNIQUE" ) val data = Seq( (12, "John Bonham", "Drums"), (19, "John Mayer", "Guitar"), (32, "John Scofield", "Guitar"), (15, "John Butler", "Guitar") ) SparkConnectorScalaSuiteIT.session().run("CREATE " + data .map(_._2) .toSet[String] .map(name => s"(:Musician{name: '$name'})") .mkString(", ")) val musicDf = data.toDF("experience", "name", "instrument") musicDf.repartition(1).write .mode(SaveMode.Overwrite) .format(classOf[DataSource].getName) .option("url", SparkConnectorScalaSuiteIT.server.getBoltUrl) .option("relationship", "PLAYS") .option("relationship.save.strategy", "keys") .option("relationship.source.save.mode", "match") .option("relationship.source.labels", ":Musician") .option("relationship.source.node.keys", "name") .option("relationship.target.save.mode", "overwrite") .option("relationship.target.labels", ":Instrument") .option("relationship.target.node.keys", "instrument:name") 
.save val count = SparkConnectorScalaSuiteIT.session().run( """MATCH p = (:Musician)-[:PLAYS]->(:Instrument) |RETURN count(p) AS count""".stripMargin ) .single() .get("count") .asLong() assertEquals(data.size, count) SparkConnectorScalaSuiteIT.session().run("DROP CONSTRAINT musician_name") } @Test def `should work match source node and merge target node with odd chars`() { val data = Seq( (12, "John Bonham", "Drums"), (19, "John Mayer", "Guitar"), (32, "John Scofield", "Guitar"), (15, "John Butler", "Guitar") ) val musicDf = data.toDF("experience", "who:name", "instrument") musicDf.repartition(1).write .mode(SaveMode.Overwrite) .format(classOf[DataSource].getName) .option("url", SparkConnectorScalaSuiteIT.server.getBoltUrl) .option("relationship", "PLAYS") .option("relationship.save.strategy", "keys") .option("relationship.source.save.mode", "overwrite") .option("relationship.source.labels", ":Musician") .option("relationship.source.node.keys", "`who:name`") .option("relationship.target.save.mode", "overwrite") .option("relationship.target.labels", ":Instrument") .option("relationship.target.node.keys", "instrument:name") .save val count = SparkConnectorScalaSuiteIT.session().run( """MATCH p = (:Musician)-[:PLAYS]->(:Instrument) |RETURN count(p) AS count""".stripMargin ) .single() .get("count") .asLong() assertEquals(data.size, count) } @Test def shouldWriteComplexDF(): Unit = { val data = Seq( ( "Cuba Gooding Jr.", 1, "2022-06-07 00:00:00", Seq(Map("product_id" -> 1, "quantity" -> 2), Map("product_id" -> 2, "quantity" -> 4)) ), ( "Tom Hanks", 2, "2022-07-07 00:00:00", Seq(Map("product_id" -> 11, "quantity" -> 2), Map("product_id" -> 22, "quantity" -> 4)) ) ).toDF("actor_name", "order_id", "order_date", "products") data.write .mode(SaveMode.Overwrite) .format(classOf[DataSource].getName) .option("url", SparkConnectorScalaSuiteIT.server.getBoltUrl) .option( "query", """ |MERGE (person:Person {name: event.actor_name}) |CREATE (order:Order {id: event.order_id, date: 
datetime(replace(event.order_date, ' ', 'T'))}) |MERGE (person)-[:CREATED]->(order) |WITH event, person, order |UNWIND event.products AS product_order |MERGE (product:Product {id: product_order.product_id}) |CREATE (order)-[:CONTAINS{quantityOrdered: product_order.quantity}]->(product) |""".stripMargin ) .save() val actual = SparkConnectorScalaSuiteIT.session().run( """ |MATCH (p:Person)-[cr:CREATED]->(o:Order)-[co:CONTAINS]->(pr:Product) |WITH p, pr, o, co |ORDER BY p.name, pr.id |RETURN p.name AS name, o.id AS order, collect({id: pr.id, quantity: co.quantityOrdered}) AS products |""".stripMargin ) .list() .asScala .map(_.asMap()) .toSet .asJava val expected = Set( Map( "name" -> "Cuba Gooding Jr.", "order" -> 1L, "products" -> List( Map("id" -> 1L, "quantity" -> 2L).asJava, Map("id" -> 2L, "quantity" -> 4L).asJava ).asJava ).asJava, Map( "name" -> "Tom Hanks", "order" -> 2L, "products" -> List( Map("id" -> 11L, "quantity" -> 2L).asJava, Map("id" -> 22L, "quantity" -> 4L).asJava ).asJava ).asJava ).asJava assertEquals(expected, actual) } @Test def shouldFix502(): Unit = { val data = Seq( ("Foo", 1, Map("inner" -> Map("key" -> "innerValue"))), ("Bar", 1, Map("inner" -> Map("key" -> "innerValue1"))) ).toDF("id", "time", "table") data.write .mode(SaveMode.Append) .format(classOf[DataSource].getName) .option("url", SparkConnectorScalaSuiteIT.server.getBoltUrl) .option("labels", ":MyNodeWithMapFlattend") .save() val count: Long = SparkConnectorScalaSuiteIT.session().run( """ |MATCH (n:MyNodeWithMapFlattend) |WHERE ( | properties(n) = {id: 'Foo', time: 1, `table.inner.key`: 'innerValue'} | OR properties(n) = {id: 'Bar', time: 1, `table.inner.key`: 'innerValue1'} |) |RETURN count(n) |""".stripMargin ) .single() .get(0) .asLong() junit.Assert.assertEquals(2L, count) } @Test def shouldFix502WithCollisions(): Unit = { val data = Seq( ("Foo", 1, ListMap("key.inner" -> Map("key" -> "innerValue"), "key" -> Map("inner.key" -> "value"))), ("Bar", 1, ListMap("key.inner" -> 
Map("key" -> "innerValue1"), "key" -> Map("inner.key" -> "value1"))) ).toDF("id", "time", "table") data.write .mode(SaveMode.Append) .format(classOf[DataSource].getName) .option("url", SparkConnectorScalaSuiteIT.server.getBoltUrl) .option("labels", ":MyNodeWithMapFlattend") .save() val count: Long = SparkConnectorScalaSuiteIT.session().run( """ |MATCH (n:MyNodeWithMapFlattend) |WHERE ( | properties(n) = {id: 'Foo', time: 1, `table.key.inner.key`: 'value'} | OR properties(n) = {id: 'Bar', time: 1, `table.key.inner.key`: 'value1'} |) |RETURN count(n) |""".stripMargin ) .single() .get(0) .asLong() junit.Assert.assertEquals(2L, count) } @Test def shouldFix502WithCollisionsAndAggregateValues(): Unit = { val data = Seq( ("Foo", 1, ListMap("key.inner" -> Map("key" -> "innerValue"), "key" -> Map("inner.key" -> "value"))), ("Bar", 1, ListMap("key.inner" -> Map("key" -> "innerValue1"), "key" -> Map("inner.key" -> "value1"))) ).toDF("id", "time", "table") data.write .mode(SaveMode.Append) .format(classOf[DataSource].getName) .option("url", SparkConnectorScalaSuiteIT.server.getBoltUrl) .option("labels", ":MyNodeWithMapFlattend") .option("schema.map.group.duplicate.keys", true) .save() val count: Long = SparkConnectorScalaSuiteIT.session().run( """ |MATCH (n:MyNodeWithMapFlattend) |WHERE ( | properties(n) = {id: 'Foo', time: 1, `table.key.inner.key`: ['innerValue', 'value']} | OR properties(n) = {id: 'Bar', time: 1, `table.key.inner.key`: ['innerValue1', 'value1']} |) |RETURN count(n) |""".stripMargin ) .single() .get(0) .asLong() junit.Assert.assertEquals(2L, count) } @Test def doesNotWriteNodePropertiesToRelationship(): Unit = { val data = Seq( ("john", "The Matrix", "today"), ("jane", "Oppenheimer", "yesterday"), ("şaban", "Hababam Sınıfı", "two days ago") ).toDF("username", "movie_title", "watch_time") data.write .mode(SaveMode.Append) .format(classOf[DataSource].getName) .option("url", SparkConnectorScalaSuiteIT.server.getBoltUrl) .option("relationship", "WATCHED") 
.option("relationship.save.strategy", "keys") .option("relationship.source.save.mode", "Overwrite") .option("relationship.source.labels", ":User") .option("relationship.source.node.keys", "username:name") .option("relationship.target.save.mode", "Overwrite") .option("relationship.target.labels", ":Movie") .option("relationship.target.node.keys", "movie_title:title") .save() val rows = SparkConnectorScalaSuiteIT.session().run( """ |MATCH (:User)-[r:WATCHED]->(:Movie) |WITH r |ORDER BY r.watch_time ASC |RETURN collect(r{.*}) |""".stripMargin ) .single() .get(0) .asList((value: Value) => value.asMap().asScala) .asScala junit.Assert.assertEquals( List( Map("watch_time" -> "today"), Map("watch_time" -> "two days ago"), Map("watch_time" -> "yesterday") ), rows ) } } ================================================ FILE: spark-3/src/test/scala/org/neo4j/spark/DefaultConfigTSE.scala ================================================ /* * Copyright (c) "Neo4j" * Neo4j Sweden AB [https://neo4j.com] * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. 
*/
package org.neo4j.spark

import org.junit.Assert.assertEquals
import org.junit.Test
import org.neo4j.driver.Transaction
import org.neo4j.driver.TransactionWork
import org.neo4j.driver.summary.ResultSummary

/**
 * Verifies that connector options can be supplied through the Spark session
 * configuration (with the "neo4j." prefix) instead of per-read `.option(...)`
 * calls.
 */
class DefaultConfigTSE extends SparkConnectorScalaBaseTSE {

  @Test
  def `when session has default parameters it should use those instead of requiring options`(): Unit = {
    // Seed a single node so the read below has exactly one row to count.
    SparkConnectorScalaSuiteIT.session()
      .writeTransaction(
        new TransactionWork[ResultSummary] {
          override def execute(tx: Transaction): ResultSummary =
            tx.run("CREATE (p:Person {name: 'Foobar'})").consume()
        }
      )

    // URL comes from the session config, not from a reader option: this is the
    // behavior under test.
    ss.conf.set("neo4j.url", SparkConnectorScalaSuiteIT.server.getBoltUrl)

    val df = ss.read.format(classOf[DataSource].getName)
      .option("labels", "Person")
      .load()

    // JUnit convention: expected value first, actual second (was reversed,
    // which produces misleading failure messages).
    assertEquals(1, df.count())
  }
}

================================================
FILE: spark-3/src/test/scala/org/neo4j/spark/GraphDataScienceIT.scala
================================================
/*
 * Copyright (c) "Neo4j"
 * Neo4j Sweden AB [https://neo4j.com]
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
*/
package org.neo4j.spark

import org.apache.spark.sql.types.ArrayType
import org.apache.spark.sql.types.DoubleType
import org.apache.spark.sql.types.LongType
import org.apache.spark.sql.types.MapType
import org.apache.spark.sql.types.StringType
import org.apache.spark.sql.types.StructField
import org.apache.spark.sql.types.StructType
import org.junit.After
import org.junit.Assert.assertEquals
import org.junit.Assert.assertTrue
import org.junit.Assert.fail
import org.junit.Assume
import org.junit.Test
import org.neo4j.Closeables.use
import org.neo4j.driver.Transaction

import scala.math.Ordering.Implicits.infixOrderingOps

/** Integration tests for reading GDS (Graph Data Science) procedure results. */
class GraphDataScienceIT extends SparkConnectorScalaSuiteWithGdsBase {

  /** Recreates the default database and drops every projected GDS graph after each test. */
  @After
  def cleanData(): Unit = {
    use(SparkConnectorScalaSuiteWithGdsBase.session("system")) { session =>
      session.run("CREATE OR REPLACE DATABASE neo4j WAIT 30 seconds").consume()
    }
    use(SparkConnectorScalaSuiteWithGdsBase.session()) { session =>
      session
        .writeTransaction((tx: Transaction) => {
          tx.run(
            """
              |CALL gds.graph.list() YIELD graphName
              |WITH graphName AS g
              |CALL gds.graph.drop(g) YIELD graphName
              |RETURN *
              |""".stripMargin
          ).consume()
        })
    }
  }

  @Test
  def shouldReturnThePageRank(): Unit = {
    initForPageRank()
    val df = ss.read.format(classOf[DataSource].getName)
      .option("url", SparkConnectorScalaSuiteWithGdsBase.server.getBoltUrl)
      .option("gds", "gds.pageRank.stream")
      .option("gds.graphName", "myGraph")
      .option("gds.configuration.concurrency", "2")
      .load()
    // JUnit convention: expected first, actual second (args were reversed).
    assertEquals(8, df.count())
    df.show(false)
    assertEquals(StructType(Array(StructField("nodeId", LongType), StructField("score", DoubleType))), df.schema)

    // The *.estimate variant uses graphNameOrConfiguration/algoConfiguration
    // parameter names instead of graphName/configuration.
    val dfEstimate = ss.read.format(classOf[DataSource].getName)
      .option("url", SparkConnectorScalaSuiteWithGdsBase.server.getBoltUrl)
      .option("gds", "gds.pageRank.stream.estimate")
      .option("gds.graphNameOrConfiguration", "myGraph")
      .option("gds.algoConfiguration.concurrency", "2")
      .load()
    assertEquals(1, dfEstimate.count())
    dfEstimate.show(false)
    assertEquals(
      StructType(
        Array(
StructField("requiredMemory", StringType), StructField("treeView", StringType), StructField("mapView", MapType(StringType, StringType)), StructField("bytesMin", LongType), StructField("bytesMax", LongType), StructField("nodeCount", LongType), StructField("relationshipCount", LongType), StructField("heapPercentageMin", DoubleType), StructField("heapPercentageMax", DoubleType) ) ), dfEstimate.schema ) } @Test def shouldFailWithUnsupportedOptions(): Unit = { initForPageRank() def run(options: Map[String, String], error: String): Unit = { try { ss.read.format(classOf[DataSource].getName) .option("url", SparkConnectorScalaSuiteWithGdsBase.server.getBoltUrl) .options(options) .load() .show(false) fail("Expected to throw an exception") } catch { case iae: IllegalArgumentException => assertTrue(iae.getMessage.equals(error)) case _: Throwable => fail(s"should be thrown a ${classOf[IllegalArgumentException].getName}") } } run( Map( "gds" -> "gds.pageRank.stream", "gds.graphName" -> "myGraph", "gds.configuration.concurrency" -> "2", "partitions" -> "2" ), "For GDS queries we support only one partition" ) run( Map( "gds" -> "gds.pageRank.write", "gds.graphName" -> "myGraph", "gds.configuration.concurrency" -> "2" ), "You cannot execute GDS mutate or write procedure in a read query" ) run( Map( "gds" -> "gds.pageRank.mutate", "gds.graphName" -> "myGraph", "gds.configuration.concurrency" -> "2" ), "You cannot execute GDS mutate or write procedure in a read query" ) } @Test def shouldWorkWithMapReturn(): Unit = { initForHits() val procName = if (TestUtil.gdsVersion(SparkConnectorScalaSuiteWithGdsBase.session()) >= Versions.GDS_2_5) "gds.hits.stream" else "gds.alpha.hits.stream" val df = ss.read.format(classOf[DataSource].getName) .option("url", SparkConnectorScalaSuiteWithGdsBase.server.getBoltUrl) .option("gds", procName) .option("gds.graphName", "myGraph") .option("gds.configuration.hitsIterations", "20") .load() assertEquals(df.count(), 9) df.show(false) assertEquals( 
StructType(Array(StructField("nodeId", LongType), StructField("values", MapType(StringType, StringType)))), df.schema ) } @Test def shouldWorkWithPathReturn(): Unit = { initForYens() val sourceTargetNodes = ss.read.format(classOf[DataSource].getName) .option("url", SparkConnectorScalaSuiteWithGdsBase.server.getBoltUrl) .option("labels", "Location") .load() .where("name IN ('A', 'F')") .orderBy("name") .collect() val (sourceId, targetId) = (sourceTargetNodes(0).getAs[Long](""), sourceTargetNodes(1).getAs[Long]("")) val df = ss.read.format(classOf[DataSource].getName) .option("url", SparkConnectorScalaSuiteWithGdsBase.server.getBoltUrl) .option("gds", "gds.shortestPath.yens.stream") .option("gds.graphName", "myGraph") .option("gds.configuration.sourceNode", sourceId) .option("gds.configuration.targetNode", targetId) .option("gds.configuration.k", 3) .option("gds.configuration.relationshipWeightProperty", "cost") .load() assertEquals(df.count(), 3) df.show(false) assertEquals( StructType( Array( StructField("index", LongType), StructField("sourceNode", LongType), StructField("targetNode", LongType), StructField("totalCost", DoubleType), StructField("nodeIds", ArrayType(LongType)), StructField("costs", ArrayType(DoubleType)), StructField("path", StringType) ) ), df.schema ) val (graphNameParam, algoConfigurationParam) = if (TestUtil.gdsVersion(SparkConnectorScalaSuiteWithGdsBase.session()) >= Versions.GDS_2_4) ("graphName", "configuration") else ("graphNameOrConfiguration", "algoConfiguration") val dfEstimate = ss.read.format(classOf[DataSource].getName) .option("url", SparkConnectorScalaSuiteWithGdsBase.server.getBoltUrl) .option("gds", "gds.shortestPath.yens.stream.estimate") .option(s"gds.$graphNameParam", "myGraph") .option(s"gds.$algoConfigurationParam.sourceNode", sourceId) .option(s"gds.$algoConfigurationParam.targetNode", targetId) .option(s"gds.$algoConfigurationParam.k", 3) .option(s"gds.$algoConfigurationParam.relationshipWeightProperty", "cost") .load() 
assertEquals(dfEstimate.count(), 1) dfEstimate.show(false) assertEquals( StructType( Array( StructField("requiredMemory", StringType), StructField("treeView", StringType), StructField("mapView", MapType(StringType, StringType)), StructField("bytesMin", LongType), StructField("bytesMax", LongType), StructField("nodeCount", LongType), StructField("relationshipCount", LongType), StructField("heapPercentageMin", DoubleType), StructField("heapPercentageMax", DoubleType) ) ), dfEstimate.schema ) } private def initForYens(): Unit = { SparkConnectorScalaSuiteWithGdsBase.session() .writeTransaction((tx: Transaction) => { tx.run( """ |CREATE (a:Location {name: 'A'}), | (b:Location {name: 'B'}), | (c:Location {name: 'C'}), | (d:Location {name: 'D'}), | (e:Location {name: 'E'}), | (f:Location {name: 'F'}), | (a)-[:ROAD {cost: 50}]->(b), | (a)-[:ROAD {cost: 50}]->(c), | (a)-[:ROAD {cost: 100}]->(d), | (b)-[:ROAD {cost: 40}]->(d), | (c)-[:ROAD {cost: 40}]->(d), | (c)-[:ROAD {cost: 80}]->(e), | (d)-[:ROAD {cost: 30}]->(e), | (d)-[:ROAD {cost: 80}]->(f), | (e)-[:ROAD {cost: 40}]->(f); |""".stripMargin ).consume() }) ss.read.format(classOf[DataSource].getName) .option("url", SparkConnectorScalaSuiteWithGdsBase.server.getBoltUrl) .option("gds", "gds.graph.project") .option("gds.graphName", "myGraph") .option("gds.nodeProjection", "Location") .option("gds.relationshipProjection", "ROAD") .option("gds.configuration.relationshipProperties", "cost") .load() .show(false) } @Test def shouldWorkWithKNearest(): Unit = { SparkConnectorScalaSuiteWithGdsBase.session() .writeTransaction((tx: Transaction) => { tx.run( """ |CREATE (alice:Person {name: 'Alice', age: 24, lotteryNumbers: [1, 3], embedding: [1.0, 3.0]}) |CREATE (bob:Person {name: 'Bob', age: 73, lotteryNumbers: [1, 2, 3], embedding: [2.1, 1.6]}) |CREATE (carol:Person {name: 'Carol', age: 24, lotteryNumbers: [3], embedding: [1.5, 3.1]}) |CREATE (dave:Person {name: 'Dave', age: 48, lotteryNumbers: [2, 4], embedding: [0.6, 0.2]}) 
|CREATE (eve:Person {name: 'Eve', age: 67, lotteryNumbers: [1, 5], embedding: [1.8, 2.7]}); |""".stripMargin ).consume() }) ss.read.format(classOf[DataSource].getName) .option("url", SparkConnectorScalaSuiteWithGdsBase.server.getBoltUrl) .option("gds", "gds.graph.project") .option("gds.graphName", "myGraph") .option("gds.nodeProjection.Person.properties", "['age','lotteryNumbers','embedding']") .option("gds.relationshipProjection", "*") .load() .show(false) val df = ss.read.format(classOf[DataSource].getName) .option("url", SparkConnectorScalaSuiteWithGdsBase.server.getBoltUrl) .option("gds", "gds.knn.stream") .option("gds.graphName", "myGraph") .option("gds.configuration.topK", 1) .option("gds.configuration.nodeProperties", "['age']") .option("gds.configuration.randomSeed", 1337) .option("gds.configuration.concurrency", 1) .option("gds.configuration.sampleRate", 1.0) .option("gds.configuration.deltaThreshold", 0.0) .load() assertEquals(df.count(), 5) df.show(false) assertEquals( StructType( Array( StructField("node1", LongType), StructField("node2", LongType), StructField("similarity", DoubleType) ) ), df.schema ) val dfEstimate = ss.read.format(classOf[DataSource].getName) .option("url", SparkConnectorScalaSuiteWithGdsBase.server.getBoltUrl) .option("gds", "gds.knn.stream.estimate") .option("gds.graphNameOrConfiguration", "myGraph") .option("gds.algoConfiguration.topK", 1) .option("gds.algoConfiguration.nodeProperties", "['age']") .option("gds.algoConfiguration.randomSeed", 1337) .option("gds.algoConfiguration.concurrency", 1) .option("gds.algoConfiguration.sampleRate", 1.0) .option("gds.algoConfiguration.deltaThreshold", 0.0) .load() assertEquals(dfEstimate.count(), 1) dfEstimate.show(false) assertEquals( StructType( Array( StructField("requiredMemory", StringType), StructField("treeView", StringType), StructField("mapView", MapType(StringType, StringType)), StructField("bytesMin", LongType), StructField("bytesMax", LongType), StructField("nodeCount", LongType), 
StructField("relationshipCount", LongType), StructField("heapPercentageMin", DoubleType), StructField("heapPercentageMax", DoubleType) ) ), dfEstimate.schema ) } private def initForPageRank(): Unit = { SparkConnectorScalaSuiteWithGdsBase.session() .writeTransaction((tx: Transaction) => { tx.run( """ |CREATE | (home:Page {name:'Home'}), | (about:Page {name:'About'}), | (product:Page {name:'Product'}), | (links:Page {name:'Links'}), | (a:Page {name:'Site A'}), | (b:Page {name:'Site B'}), | (c:Page {name:'Site C'}), | (d:Page {name:'Site D'}), | | (home)-[:LINKS {weight: 0.2}]->(about), | (home)-[:LINKS {weight: 0.2}]->(links), | (home)-[:LINKS {weight: 0.6}]->(product), | (about)-[:LINKS {weight: 1.0}]->(home), | (product)-[:LINKS {weight: 1.0}]->(home), | (a)-[:LINKS {weight: 1.0}]->(home), | (b)-[:LINKS {weight: 1.0}]->(home), | (c)-[:LINKS {weight: 1.0}]->(home), | (d)-[:LINKS {weight: 1.0}]->(home), | (links)-[:LINKS {weight: 0.8}]->(home), | (links)-[:LINKS {weight: 0.05}]->(a), | (links)-[:LINKS {weight: 0.05}]->(b), | (links)-[:LINKS {weight: 0.05}]->(c), | (links)-[:LINKS {weight: 0.05}]->(d); |""".stripMargin ).consume() }) ss.read.format(classOf[DataSource].getName) .option("url", SparkConnectorScalaSuiteWithGdsBase.server.getBoltUrl) .option("gds", "gds.graph.project") .option("gds.graphName", "myGraph") .option("gds.nodeProjection", "Page") .option("gds.relationshipProjection", "LINKS") .option("gds.configuration.relationshipProperties", "weight") .load() .show(false) } private def initForHits(): Unit = { Assume.assumeTrue(TestUtil.neo4jVersion(SparkConnectorScalaSuiteWithGdsBase.session()) >= Versions.NEO4J_5) SparkConnectorScalaSuiteWithGdsBase.session() .writeTransaction((tx: Transaction) => { tx.run( """ CREATE | (a:Website {name: 'A'}), | (b:Website {name: 'B'}), | (c:Website {name: 'C'}), | (d:Website {name: 'D'}), | (e:Website {name: 'E'}), | (f:Website {name: 'F'}), | (g:Website {name: 'G'}), | (h:Website {name: 'H'}), | (i:Website {name: 'I'}), | 
| (a)-[:LINK]->(b), | (a)-[:LINK]->(c), | (a)-[:LINK]->(d), | (b)-[:LINK]->(c), | (b)-[:LINK]->(d), | (c)-[:LINK]->(d), | | (e)-[:LINK]->(b), | (e)-[:LINK]->(d), | (e)-[:LINK]->(f), | (e)-[:LINK]->(h), | | (f)-[:LINK]->(g), | (f)-[:LINK]->(i), | (f)-[:LINK]->(h), | (g)-[:LINK]->(h), | (g)-[:LINK]->(i), | (h)-[:LINK]->(i); |""".stripMargin ).consume() }) ss.read.format(classOf[DataSource].getName) .option("url", SparkConnectorScalaSuiteWithGdsBase.server.getBoltUrl) .option("gds", "gds.graph.project") .option("gds.graphName", "myGraph") .option("gds.nodeProjection", "Website") .option("gds.relationshipProjection.LINK.indexInverse", "true") .load() .show(false) } } ================================================ FILE: spark-3/src/test/scala/org/neo4j/spark/ReauthenticationIT.scala ================================================ /* * Copyright (c) "Neo4j" * Neo4j Sweden AB [https://neo4j.com] * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. 
*/ package org.neo4j.spark import org.junit.AfterClass import org.junit.Assert.assertEquals import org.junit.BeforeClass import org.junit.Test import org.neo4j.Neo4jContainerExtension import org.neo4j.driver.AuthTokens import org.neo4j.driver.Driver import org.neo4j.driver.GraphDatabase import org.neo4j.spark.ReauthenticationIT.KEYCLOAK import org.neo4j.spark.ReauthenticationIT.NEO4J import org.neo4j.spark.SparkConnectorScalaSuiteIT.ss import org.slf4j.LoggerFactory import org.testcontainers.containers.GenericContainer import org.testcontainers.containers.Network import org.testcontainers.containers.output.Slf4jLogConsumer import org.testcontainers.containers.wait.strategy.Wait import org.testcontainers.containers.wait.strategy.WaitAllStrategy import org.testcontainers.utility.MountableFile import java.time.Duration.ofMinutes object ReauthenticationIT { private val log = LoggerFactory.getLogger(classOf[ReauthenticationIT]) private class TestKeycloakContainer(image: String) extends GenericContainer[TestKeycloakContainer](image) { def getHttpPort: Integer = this.getMappedPort(8080) } private val NETWORK = Network.newNetwork private val KEYCLOAK = new TestKeycloakContainer("quay.io/keycloak/keycloak:26.2.5") .withNetwork(NETWORK) .withEnv("KC_BOOTSTRAP_ADMIN_USERNAME", "admin") .withEnv("KC_BOOTSTRAP_ADMIN_PASSWORD", "admin") .withNetworkAliases("keycloak") .withExposedPorts(8080, 9000, 8443) .withCopyFileToContainer( MountableFile.forClasspathResource("/neo4j-keycloak.jks"), "/opt/keycloak/conf/server.keystore" ) .withEnv("KC_HTTPS_KEY_STORE_FILE", "/opt/keycloak/conf/server.keystore") .withEnv("KC_HTTPS_KEY_STORE_PASSWORD", "testpwd") .withEnv("KC_HTTPS_KEY_PASSWORD", "testpwd") .withEnv("KC_HEALTH_ENABLED", "true") .withCopyFileToContainer( MountableFile.forClasspathResource("/neo4j-sso-test-realm.json"), "/opt/keycloak/data/import/neo4j-sso-test-realm.json" ) .withEnv("KC_HOSTNAME", "https://keycloak:8443") .withEnv("KC_HOSTNAME_BACKCHANNEL_DYNAMIC", "true") 
.waitingFor( new WaitAllStrategy(WaitAllStrategy.Mode.WITH_INDIVIDUAL_TIMEOUTS_ONLY) .withStrategy(Wait.forListeningPort().withStartupTimeout(ofMinutes(2))) .withStrategy( Wait.forHttp("/health/started") .forPort(9000) .usingTls() .allowInsecure() .forStatusCode(200) .withStartupTimeout(java.time.Duration.ofMinutes(5)) ) ) .withStartupAttempts(3) .withLogConsumer(new Slf4jLogConsumer(log)) .withCommand("start-dev --import-realm") private val NEO4J = new Neo4jContainerExtension() .withNetwork(NETWORK) .withEnv("NEO4J_ACCEPT_LICENSE_AGREEMENT", "yes") .withCopyFileToContainer(MountableFile.forClasspathResource("/neo4j-keycloak.jks"), "/tmp/keycloak.jks") .withEnv( "_JAVA_OPTIONS", "-Djavax.net.ssl.keyStore=/tmp/keycloak.jks -Djavax.net.ssl.keyStorePassword=testpwd -Djavax.net.ssl.trustStore=/tmp/keycloak.jks -Djavax.net.ssl.trustStorePassword=testpwd" ) .withNeo4jConfig("dbms.security.authentication_providers", "oidc-keycloak,native") .withNeo4jConfig("dbms.security.authorization_providers", "oidc-keycloak,native") .withNeo4jConfig("dbms.security.oidc.keycloak.display_name", "Keycloak") .withNeo4jConfig("dbms.security.oidc.keycloak.auth_flow", "pkce") .withNeo4jConfig( "dbms.security.oidc.keycloak.well_known_discovery_uri", "https://keycloak:8443/realms/neo4j-sso-test/.well-known/openid-configuration" ) .withNeo4jConfig( "dbms.security.oidc.keycloak.params", "client_id=neo4j-commons-client;response_type=code;scope=openid email roles" ) .withNeo4jConfig("dbms.security.oidc.keycloak.audience", "account") .withNeo4jConfig("dbms.security.oidc.keycloak.issuer", "https://keycloak:8443/realms/neo4j-sso-test") .withNeo4jConfig("dbms.security.oidc.keycloak.client_id", "neo4j-commons-client") .withNeo4jConfig("dbms.security.oidc.keycloak.claims.username", "preferred_username") .withNeo4jConfig("dbms.security.oidc.keycloak.claims.groups", "groups") .withNeo4jConfig("dbms.security.auth_cache_ttl", "1s") @BeforeClass def setUp(): Unit = { KEYCLOAK.start() NEO4J.start() } 
@AfterClass def tearDown() = { TestUtil.closeSafely(NEO4J) TestUtil.closeSafely(KEYCLOAK) TestUtil.closeSafely(NETWORK) } } class ReauthenticationIT extends SparkConnectorScalaSuiteIT { @Test def createAnInstanceOfReAuthDriver(): Unit = { val options = Map( "url" -> NEO4J.getBoltUrl, "authentication.type" -> "keycloak", "authentication.keycloak.username" -> "john-tester", "authentication.keycloak.password" -> "testerpwd", "authentication.keycloak.authServerUrl" -> s"http://${KEYCLOAK.getHost}:${KEYCLOAK.getHttpPort}", "authentication.keycloak.realm" -> "neo4j-sso-test", "authentication.keycloak.clientId" -> "neo4j-commons-client", "authentication.keycloak.clientSecret" -> "QNrSpbh0mxhnlYlI21UcBaz3Htb734vi" ) var driver: Driver = null try { driver = GraphDatabase.driver(NEO4J.getBoltUrl, AuthTokens.basic("neo4j", NEO4J.getAdminPassword)) driver.session().run(" CREATE (n:Test {field: 42}) CREATE (t:Test {field: 45})").consume() } finally { driver.close() } val df = ss.read.format(classOf[DataSource].getName) .options(options) .option("query", "MATCH (t:Test {field: 42}) RETURN t.field") .load() .toDF() assertEquals(42, df.first().getLong(0)) Thread.sleep(4000) val df2 = ss.read.format(classOf[DataSource].getName) .options(options) .option("query", "MATCH (t:Test {field: 45}) RETURN t.field") .load() .toDF() assertEquals(45, df2.first().getLong(0)) } } ================================================ FILE: spark-3/src/test/scala/org/neo4j/spark/SparkConnector30ScalaSuiteIT.scala ================================================ /* * Copyright (c) "Neo4j" * Neo4j Sweden AB [https://neo4j.com] * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. 
 * You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.neo4j.spark

import org.junit.runner.RunWith
import org.junit.runners.Suite

// JUnit 4 suite that groups the TSE test classes running against the plain (no-APOC)
// Neo4j test container; the shared container lifecycle comes from the parent suite class.
@RunWith(classOf[Suite])
@Suite.SuiteClasses(Array(
  classOf[DataSourceReaderTSE],
  classOf[DataSourceReaderNeo4jTSE],
  classOf[DataSourceWriterNeo4jTSE],
  classOf[DataSourceWriterTSE],
  classOf[DataSourceSchemaWriterTSE],
  classOf[DefaultConfigTSE],
  classOf[DataSourceStreamingReaderTSE],
  classOf[DataSourceStreamingWriterTSE],
  classOf[DataSourceReaderAggregationTSE]
))
class SparkConnector30ScalaSuiteIT extends SparkConnectorScalaSuiteIT {}

================================================
FILE: spark-3/src/test/scala/org/neo4j/spark/SparkConnector30ScalaSuiteWithApocIT.scala
================================================
/*
 * Copyright (c) "Neo4j"
 * Neo4j Sweden AB [https://neo4j.com]
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.neo4j.spark

import org.junit.runner.RunWith
import org.junit.runners.Suite

// JUnit 4 suite grouping the TSE test classes that require the APOC-enabled container.
@RunWith(classOf[Suite])
@Suite.SuiteClasses(Array(
  classOf[DataSourceReaderWithApocTSE],
  classOf[DataSourceReaderNeo4jWithApocTSE],
  classOf[DataSourceReaderAggregationTSE]
))
class SparkConnector30ScalaSuiteWithApocIT extends SparkConnectorScalaSuiteWithApocIT {}

================================================
FILE: spark-3/src/test/scala/org/neo4j/spark/SparkConnectorAuraTest.scala
================================================
/*
 * Copyright (c) "Neo4j"
 * Neo4j Sweden AB [https://neo4j.com]
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.neo4j.spark

import org.apache.spark.SparkConf
import org.apache.spark.sql.SparkSession
import org.junit.AfterClass
import org.junit.Assert._
import org.junit.Assume.assumeTrue
import org.junit.Before
import org.junit.BeforeClass
import org.junit.Test
import org.neo4j.Closeables.use
import org.neo4j.driver._
import org.neo4j.spark.SparkConnectorAuraTest._

/**
 * Companion holding the shared driver and SparkSession for the live-Aura smoke test.
 * Connection settings come from the AURA_USER / AURA_PASSWORD / AURA_URI environment
 * variables; when any is missing, assumeTrue skips the whole class.
 */
object SparkConnectorAuraTest {

  private var neo4j: Driver = _

  private val username: Option[String] = Option[String](System.getenv("AURA_USER"))
  private val password: Option[String] = Option[String](System.getenv("AURA_PASSWORD"))
  private val url: Option[String] = Option[String](System.getenv("AURA_URI"))

  var sparkSession: SparkSession = _

  @BeforeClass
  def setUpClass(): Unit = {
    // Skip (not fail) the suite when Aura credentials are not configured.
    assumeTrue(username.isDefined)
    assumeTrue(password.isDefined)
    assumeTrue(url.isDefined)

    sparkSession = SparkSession.builder()
      .config(new SparkConf()
        .setAppName("neoTest")
        .setMaster("local[*]")
        .set("spark.driver.host", "127.0.0.1"))
      .getOrCreate()

    neo4j = GraphDatabase.driver(url.get, AuthTokens.basic(username.get, password.get))
  }

  @AfterClass
  def tearDown(): Unit = {
    TestUtil.closeSafely(neo4j)
    TestUtil.closeSafely(sparkSession)
  }
}

class SparkConnectorAuraTest {

  // getOrCreate() returns the session already built in setUpClass.
  val ss: SparkSession = SparkSession.builder().getOrCreate()

  import ss.implicits._

  @Before
  def setUp(): Unit = {
    // Reset the default database before each test so runs are independent.
    use(neo4j.session(SessionConfig.forDatabase("system"))) { session =>
      session.run("CREATE OR REPLACE DATABASE neo4j WAIT 30 seconds").consume()
    }
  }

  /** Round-trip: write a relationship graph to Aura, then read the source nodes back. */
  @Test
  def shouldWriteToAndReadFromAura(): Unit = {
    val df = Seq(("John Bonham", "Drums", 12), ("John Mayer", "Guitar", 8))
      .toDF("name", "instrument", "experience")

    df.write
      .mode("Overwrite")
      .format(classOf[DataSource].getName)
      .option("url", url.get)
      .option("authentication.type", "basic")
      .option("authentication.basic.username", username.get)
      .option("authentication.basic.password", password.get)
      .option("relationship", "PLAYS")
      .option("relationship.source.save.mode", "Append")
      .option("relationship.target.save.mode", "Append")
      .option("relationship.save.strategy", "keys")
      .option("relationship.source.labels", ":Musician")
      .option("relationship.source.node.keys", "name:name")
      .option("relationship.target.labels", ":Instrument")
      .option("relationship.target.node.keys", "instrument:name")
      .save()

    val results = sparkSession.read.format(classOf[DataSource].getName)
      .option("url", url.get)
      .option("authentication.type", "basic")
      .option("authentication.basic.username", username.get)
      .option("authentication.basic.password", password.get)
      .option("labels", "Musician")
      .load()
      .collectAsList()

    assertEquals(2, results.size())
  }
}

================================================
FILE: spark-3/src/test/scala/org/neo4j/spark/TransactionTimeoutIT.scala
================================================
/*
 * Copyright (c) "Neo4j"
 * Neo4j Sweden AB [https://neo4j.com]
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.neo4j.spark

import org.apache.spark.SparkException
import org.apache.spark.sql.SparkSession
import org.junit.AfterClass
import org.junit.Assert.assertEquals
import org.junit.Assert.assertThrows
import org.junit.Assert.assertTrue
import org.junit.BeforeClass
import org.junit.Test
import org.neo4j.Neo4jContainerExtension
import org.neo4j.driver.exceptions.ClientException
import org.neo4j.spark.SparkConnectorScalaSuiteWithApocIT.conf
import org.neo4j.spark.SparkConnectorScalaSuiteWithApocIT.server
import org.neo4j.spark.SparkConnectorScalaSuiteWithApocIT.ss
import org.neo4j.spark.TransactionTimeoutIT.NEO4J_LOW_TX_TIMEOUT

import java.util.TimeZone

/**
 * Verifies the `db.transaction.timeout` option: a sufficiently large timeout lets a slow
 * query (apoc.util.sleep) finish, while a small one makes the server terminate the
 * transaction, surfaced as a ClientException.
 */
class TransactionTimeoutIT extends SparkConnectorScalaSuiteWithApocIT {

  /** Query sleeps 3 x 1000ms; a 4000ms timeout is enough, so all rows come back. */
  @Test
  def sparkConnectorRespectsTransactionTimeout(): Unit = {
    val cypher = "UNWIND range(1, 3) AS i " +
      "CALL apoc.util.sleep(1000) " +
      "RETURN i as number"

    val df = ss.read.format("org.neo4j.spark.DataSource")
      .option("url", server.getBoltUrl)
      .option("authentication.basic.username", "neo4j")
      .option("authentication.basic.password", server.getAdminPassword)
      .option("db.transaction.timeout", "4000")
      .option("query", cypher)
      .load()
      .toDF()

    val results = df.select("number").rdd.map(_.getLong(0)).collect().toList
    val expected = Range.inclusive(1, 3).map(_.toLong).toList
    assertEquals(expected, results)
  }

  /** Timeout configured on the Spark session (neo4j.* conf); 20s of sleep must exceed it. */
  @Test
  def sparkConnectorFailsWithTransactionTimeoutWhenSetOnSessionLevel(): Unit = {
    val newConf = conf.clone().set("neo4j.url", server.getBoltUrl)
      .set("neo4j.authentication.basic.username", "neo4j")
      .set("neo4j.authentication.basic.password", server.getAdminPassword)
      .set("neo4j.db.transaction.timeout", "1000")
    val session = SparkSession.builder.config(newConf).getOrCreate()

    val cypher = "UNWIND range(1, 20) AS i " +
      "CALL apoc.util.sleep(1000) " +
      "RETURN i as number"

    val df = session.read.format("org.neo4j.spark.DataSource")
      .option("query", cypher)

    val exc = assertThrows(
      classOf[ClientException],
      () => {
        df.load()
          .toDF()
          .select("number").rdd.map(_.getLong(0)).collect().toList
      }
    )
    assertTrue(exc.getMessage.contains("The transaction has been terminated"))
  }

  /** Same as above but the timeout comes from the datasource option, not the session conf. */
  @Test
  def sparkConnectorFailsWithTransactionTimeoutWhenSetOnDatasourceLevel(): Unit = {
    val newConf = conf.clone().set("neo4j.url", server.getBoltUrl)
      .set("neo4j.authentication.basic.username", "neo4j")
      .set("neo4j.authentication.basic.password", server.getAdminPassword)
    val session = SparkSession.builder.config(newConf).getOrCreate()

    val cypher = "UNWIND range(1, 20) AS i " +
      "CALL apoc.util.sleep(1000) " +
      "RETURN i as number"

    val df = session.read.format("org.neo4j.spark.DataSource")
      .option("query", cypher)
      .option("db.transaction.timeout", "1000")

    val exc = assertThrows(
      classOf[ClientException],
      () => {
        df.load()
          .toDF()
          .select("number").rdd.map(_.getLong(0)).collect().toList
      }
    )
    assertTrue(exc.getMessage.contains("The transaction has been terminated"))
  }

  /**
   * Against a server whose default db.transaction.timeout is 10s, a 12s query succeeds
   * when the connector raises the timeout to 15s — i.e. the option can extend the default.
   */
  @Test
  def sparkConnectorExtendsDefaultTimeout(): Unit = {
    val cypher = "UNWIND range(1, 6) AS i " +
      "CALL apoc.util.sleep(2000) " +
      "RETURN i as number"

    val df = ss.read.format("org.neo4j.spark.DataSource")
      .option("url", NEO4J_LOW_TX_TIMEOUT.getBoltUrl)
      .option("authentication.basic.username", "neo4j")
      .option("authentication.basic.password", NEO4J_LOW_TX_TIMEOUT.getAdminPassword)
      .option("db.transaction.timeout", "15000")
      .option("query", cypher)
      .load()
      .toDF()

    val results = df.select("number").rdd.map(_.getLong(0)).collect().toList
    val expected = Range.inclusive(1, 6).map(_.toLong).toList
    assertEquals(expected, results)
  }
}

/** Companion: a dedicated container with a deliberately low (10s) default tx timeout. */
object TransactionTimeoutIT {

  private val NEO4J_LOW_TX_TIMEOUT = new Neo4jContainerExtension {
    withNeo4jConfig("dbms.security.auth_enabled", "false")
    withEnv("NEO4J_ACCEPT_LICENSE_AGREEMENT", "yes")
    withEnv("NEO4JLABS_PLUGINS", "[\"apoc\"]")
    withEnv("NEO4J_db_temporal_timezone", TimeZone.getDefault.getID)
    withNeo4jConfig("db.transaction.timeout", "10s")
    withDatabases(Seq("db1", "db2"))
  }

  @BeforeClass
  def setUp(): Unit = {
    NEO4J_LOW_TX_TIMEOUT.start()
  }

  @AfterClass
  def tearDown() = {
    TestUtil.closeSafely(NEO4J_LOW_TX_TIMEOUT)
  }
}

================================================
FILE: test-support/pom.xml
================================================
4.0.0 org.neo4j neo4j-connector-apache-spark_parent 5.4.3-SNAPSHOT neo4j-connector-apache-spark_test-support jar neo4j-connector-apache-spark-test-support Test Utilities for Neo4j Connector for Apache Spark using the binary Bolt Driver true 2.0.9 org.objenesis objenesis 3.5 junit junit org.hamcrest hamcrest org.neo4j caniuse-core org.neo4j caniuse-neo4j-detection org.neo4j.driver neo4j-java-driver-slim org.powermock powermock-api-mockito2 ${powermock.version} org.powermock powermock-module-junit4 ${powermock.version} org.scalatest scalatest_${scala.binary.version} org.scalatestplus junit-4-13_${scala.binary.version} org.testcontainers testcontainers org.jetbrains annotations org.testcontainers testcontainers-neo4j org.apache.spark spark-core_${scala.binary.version} provided org.apache.spark spark-sql_${scala.binary.version} provided net.alchim31.maven scala-maven-plugin org.apache.maven.plugins maven-failsafe-plugin org.apache.maven.plugins maven-surefire-plugin org.apache.maven.plugins maven-deploy-plugin true

================================================
FILE: test-support/src/main/java/org/neo4j/spark/Assert.java
================================================
/*
 * Copyright (c) "Neo4j"
 * Neo4j Sweden AB [https://neo4j.com]
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and * limitations under the License. */ package org.neo4j.spark;// // Source code recreated from a .class file by IntelliJ IDEA // (powered by Fernflower decompiler) // import java.util.Arrays; import java.util.concurrent.TimeUnit; import java.util.function.Function; import java.util.function.Supplier; import org.hamcrest.Description; import org.hamcrest.Matcher; import org.hamcrest.MatcherAssert; import org.hamcrest.StringDescription; import org.hamcrest.core.StringContains; public final class Assert { private Assert() { } public interface ThrowingSupplier { T get() throws E; static ThrowingSupplier throwingSupplier(final Supplier supplier) { return new ThrowingSupplier() { public TYPE get() { return supplier.get(); } public String toString() { return supplier.toString(); } }; } } public interface ThrowingAction { void apply() throws E; static ThrowingAction noop() { return () -> { }; } } public static void assertException(ThrowingAction f, Class typeOfException) { assertException(f, typeOfException, (String) null); } public static void assertException(ThrowingAction f, Class typeOfException, String partOfErrorMessage) { try { f.apply(); org.junit.Assert.fail("Expected exception of type " + typeOfException + ", but no exception was thrown"); } catch (Exception var4) { if (typeOfException.isInstance(var4)) { if (partOfErrorMessage != null) { MatcherAssert.assertThat(var4.getMessage(), StringContains.containsString(partOfErrorMessage)); } } else { org.junit.Assert.fail("Got unexpected exception " + var4.getClass() + "\nExpected: " + typeOfException); } } } public static void assertEventually(ThrowingSupplier actual, Matcher matcher, long timeout, TimeUnit timeUnit) throws E, InterruptedException { assertEventually((ignored) -> { return ""; }, actual, matcher, timeout, timeUnit); } public static void assertEventually(String reason, ThrowingSupplier actual, Matcher matcher, long timeout, TimeUnit 
timeUnit) throws E, InterruptedException { assertEventually((ignored) -> { return reason; }, actual, matcher, timeout, timeUnit); } public static void assertEventually(Function reason, ThrowingSupplier actual, Matcher matcher, long timeout, TimeUnit timeUnit) throws E, InterruptedException { long endTimeMillis = System.currentTimeMillis() + timeUnit.toMillis(timeout); while (true) { long sampleTime = System.currentTimeMillis(); T last = actual.get(); boolean matched = matcher.matches(last); if (matched || sampleTime > endTimeMillis) { if (!matched) { Description description = new StringDescription(); description.appendText((String) reason.apply(last)).appendText("\nExpected: ").appendDescriptionOf(matcher).appendText("\n but: "); matcher.describeMismatch(last, description); throw new AssertionError("Timeout hit (" + timeout + " " + timeUnit.toString().toLowerCase() + ") while waiting for condition to match: " + description.toString()); } else { return; } } Thread.sleep(100L); } } private static AssertionError newAssertionError(String message, Object expected, Object actual) { return new AssertionError((message != null && !message.isEmpty() ? 
message + "\n" : "") + "Expected: " + prettyPrint(expected) + ", actual: " + prettyPrint(actual)); } private static String prettyPrint(Object o) { if (o == null) { return "null"; } Class clazz = o.getClass(); if (clazz.isArray()) { if (clazz == byte[].class) { return Arrays.toString((byte[]) o); } else if (clazz == short[].class) { return Arrays.toString((short[]) o); } else if (clazz == int[].class) { return Arrays.toString((int[]) o); } else if (clazz == long[].class) { return Arrays.toString((long[]) o); } else if (clazz == float[].class) { return Arrays.toString((float[]) o); } else if (clazz == double[].class) { return Arrays.toString((double[]) o); } else if (clazz == char[].class) { return Arrays.toString((char[]) o); } else if (clazz == boolean[].class) { return Arrays.toString((boolean[]) o); } else { return Arrays.deepToString((Object[]) o); } } else { return String.valueOf(o); } } } ================================================ FILE: test-support/src/main/resources/simplelogger.properties ================================================ org.slf4j.simpleLogger.defaultLogLevel=error org.slf4j.simpleLogger.showDateTime=true org.slf4j.simpleLogger.dateTimeFormat=yyyy-MM-dd'T'HH:mm:ss.SSS ================================================ FILE: test-support/src/main/scala/org/neo4j/Closeables.scala ================================================ /* * Copyright (c) "Neo4j" * Neo4j Sweden AB [https://neo4j.com] * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. 
*/
package org.neo4j

/** Loan-pattern helper: runs `code` against `resource` and always closes it. */
object Closeables {

  def use[A <: AutoCloseable, B](resource: A)(code: A => B): B =
    try code(resource)
    finally resource.close()
}
================================================ FILE: test-support/src/main/scala/org/neo4j/Neo4jContainerExtension.scala ================================================
/*
 * Copyright (c) "Neo4j"
 * Neo4j Sweden AB [https://neo4j.com]
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package org.neo4j

import org.neo4j.driver.AuthToken
import org.neo4j.driver.AuthTokens
import org.neo4j.driver.GraphDatabase
import org.neo4j.driver.SessionConfig
import org.neo4j.spark.TestUtil
import org.rnorth.ducttape.unreliables.Unreliables
import org.testcontainers.containers.wait.strategy.AbstractWaitStrategy
import org.testcontainers.containers.wait.strategy.WaitAllStrategy
import org.testcontainers.neo4j.Neo4jContainer

import java.time.Duration
import java.util.concurrent.Callable
import java.util.concurrent.TimeUnit
import scala.collection.JavaConverters._
import scala.io.Source

/**
 * Wait strategy that blocks container startup until every registered database
 * exists and reports `currentStatus = "online"` when queried through the
 * `system` database.
 *
 * @param auth credentials used to open a Bolt connection to the container
 */
class DatabasesWaitStrategy(private val auth: AuthToken) extends AbstractWaitStrategy {
  private var databases = Seq.empty[String]

  /** Registers additional databases to wait for; returns `this` for chaining. */
  def forDatabases(dbs: Seq[String]): DatabasesWaitStrategy = {
    databases ++= dbs
    this
  }

  override def waitUntilReady(): Unit = {
    val boltUrl = s"bolt://${waitStrategyTarget.getHost}:${waitStrategyTarget.getMappedPort(7687)}"
    val driver = GraphDatabase.driver(boltUrl, auth)
    try {
      // Retry the whole check until the configured startup timeout elapses.
      Unreliables.retryUntilSuccess(
        startupTimeout.getSeconds.toInt,
        TimeUnit.SECONDS,
        new Callable[Unit] {
          override def call(): Unit = {
            val session = driver.session(SessionConfig.forDatabase("system"))
            try {
              databases.foreach { db =>
                // Idempotent create: WAIT blocks server-side until the DB is provisioned.
                session.writeTransaction(tx =>
                  tx.run(s"CREATE DATABASE $db IF NOT EXISTS WAIT 30 SECONDS").consume()
                )
                val status = session.readTransaction(tx => {
                  tx.run(s"SHOW DATABASE $db YIELD currentStatus").single().get("currentStatus").asString()
                })
                if (status != "online") {
                  // Throwing makes retryUntilSuccess try again until the timeout.
                  throw new RuntimeException(s"Database $db is not online yet, current status: $status")
                }
              }
            } finally {
              session.close()
            }
          }
        }
      )
    } finally {
      driver.close()
    }
  }
}

// docker pull neo4j/neo4j-experimental:4.0.0-rc01-enterprise
/**
 * Test extension of the Testcontainers Neo4j container that can wait for
 * extra databases to come online and load Cypher fixture files after start.
 */
class Neo4jContainerExtension extends Neo4jContainer(
      TestUtil.neo4jImage()
    ) {
  private var databases: Seq[String] = Seq.empty
  // (database name, classpath resource with ';'-separated Cypher statements)
  private var fixture: Set[(String, String)] = Set.empty

  /** Registers databases that must be online before `start()` returns. */
  def withDatabases(dbs: Seq[String]): Neo4jContainerExtension = {
    databases ++= dbs
    this
  }

  /** Registers a Cypher fixture resource to run against `database` after startup. */
  def withFixture(database: String, path: String): Neo4jContainerExtension = {
    fixture ++= Set((database, path))
    this
  }

  private def createAuth(): AuthToken =
    if (getAdminPassword.nonEmpty) AuthTokens.basic("neo4j", getAdminPassword) else AuthTokens.none()

  override def start(): Unit = {
    if (databases.nonEmpty) {
      val waitAllStrategy = waitStrategy.asInstanceOf[WaitAllStrategy]
      waitAllStrategy.withStrategy(
        new DatabasesWaitStrategy(createAuth()).forDatabases(databases).withStartupTimeout(Duration.ofMinutes(2))
      )
    }
    addEnv("NEO4J_ACCEPT_LICENSE_AGREEMENT", "yes")
    super.start()
    if (fixture.nonEmpty) {
      val driver = GraphDatabase.driver(this.getBoltUrl, createAuth())
      try {
        fixture.foreach { case (database, path) =>
          val session = driver.session(SessionConfig.forDatabase(database))
          try {
            // FIX: Source is an Iterator[Char]; calling mkString("\n") on it directly
            // would interleave a newline between every single character of the
            // fixture. Read whole lines first, then join and split into statements.
            val statements = Source.fromResource(path)
              .getLines()
              .mkString("\n")
              .split(";")
              .map(_.trim)
              .filter(_.nonEmpty) // FIX: split can yield blank fragments; running "" would fail
            // consume() forces execution of each statement before moving on.
            statements.foreach(statement => session.run(statement).consume())
          } finally {
            TestUtil.closeSafely(session)
          }
        }
      } finally {
        TestUtil.closeSafely(driver)
      }
    }
  }
}
================================================ FILE:
test-support/src/main/scala/org/neo4j/spark/RowUtil.scala ================================================ /* * Copyright (c) "Neo4j" * Neo4j Sweden AB [https://neo4j.com] * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ package org.neo4j.spark import org.apache.spark.sql.Row object RowUtil { def getByName[T](row: Row, name: String): T = row.getAs[T](row.fieldIndex(name)) } ================================================ FILE: test-support/src/main/scala/org/neo4j/spark/SparkConnectorScalaBaseTSE.scala ================================================ /* * Copyright (c) "Neo4j" * Neo4j Sweden AB [https://neo4j.com] * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. 
*/ package org.neo4j.spark import org.apache.spark.SparkConf import org.apache.spark.sql.SparkSession import org.junit._ import org.junit.rules.TestName import org.neo4j.Closeables.use import org.scalatestplus.junit.AssertionsForJUnit import scala.annotation.meta.getter object SparkConnectorScalaBaseTSE { private var startedFromSuite = true @BeforeClass def setUpContainer() = { if (!SparkConnectorScalaSuiteIT.server.isRunning) { startedFromSuite = false SparkConnectorScalaSuiteIT.setUpContainer() } } @AfterClass def tearDownContainer() = { if (!startedFromSuite) { SparkConnectorScalaSuiteIT.tearDownContainer() } } } class SparkConnectorScalaBaseTSE extends AssertionsForJUnit { val conf: SparkConf = SparkConnectorScalaSuiteIT.conf val ss: SparkSession = SparkConnectorScalaSuiteIT.ss @(Rule @getter) val testName: TestName = new TestName @Before def before(): Unit = { use(SparkConnectorScalaSuiteIT.session("system")) { session => session .run("CREATE OR REPLACE DATABASE neo4j WAIT 30 seconds").consume() } } @After def after(): Unit = { ss.catalog.listTables() .collect() .foreach(t => ss.catalog.dropTempView(t.name)) ss.catalog.listTables() .collect() .foreach(t => ss.catalog.dropGlobalTempView(t.name)) } } ================================================ FILE: test-support/src/main/scala/org/neo4j/spark/SparkConnectorScalaBaseWithApocTSE.scala ================================================ /* * Copyright (c) "Neo4j" * Neo4j Sweden AB [https://neo4j.com] * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
* See the License for the specific language governing permissions and * limitations under the License. */ package org.neo4j.spark import org.apache.spark.SparkConf import org.apache.spark.sql.SparkSession import org.hamcrest.Matchers import org.junit._ import org.junit.rules.TestName import org.neo4j.Closeables.use import org.neo4j.driver.Transaction import org.neo4j.driver.TransactionWork import org.neo4j.driver.summary.ResultSummary import org.neo4j.spark import java.util.concurrent.TimeUnit import scala.annotation.meta.getter object SparkConnectorScalaBaseWithApocTSE { private var startedFromSuite = true @BeforeClass def setUpContainer() = { if (!SparkConnectorScalaSuiteWithApocIT.server.isRunning) { startedFromSuite = false SparkConnectorScalaSuiteWithApocIT.setUpContainer() } } @AfterClass def tearDownContainer() = { if (!startedFromSuite) { SparkConnectorScalaSuiteWithApocIT.tearDownContainer() } } } class SparkConnectorScalaBaseWithApocTSE { val conf: SparkConf = SparkConnectorScalaSuiteWithApocIT.conf val ss: SparkSession = SparkConnectorScalaSuiteWithApocIT.ss @(Rule @getter) val testName: TestName = new TestName @Before def before() { use(SparkConnectorScalaSuiteWithApocIT.session("system")) { session => session.run("CREATE OR REPLACE DATABASE neo4j WAIT 30 seconds") .consume() } } } ================================================ FILE: test-support/src/main/scala/org/neo4j/spark/SparkConnectorScalaSuiteIT.scala ================================================ /* * Copyright (c) "Neo4j" * Neo4j Sweden AB [https://neo4j.com] * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
* See the License for the specific language governing permissions and * limitations under the License. */ package org.neo4j.spark import org.apache.spark.SparkConf import org.apache.spark.sql.SparkSession import org.junit.AfterClass import org.junit.Assume import org.junit.BeforeClass import org.neo4j.Neo4jContainerExtension import org.neo4j.caniuse.Neo4j import org.neo4j.caniuse.Neo4jDetector import org.neo4j.driver._ import java.io.File import java.nio.file.Files import java.util.TimeZone object SparkConnectorScalaSuiteIT { val server: Neo4jContainerExtension = new Neo4jContainerExtension { withNeo4jConfig("dbms.security.auth_enabled", "false") withEnv("NEO4J_ACCEPT_LICENSE_AGREEMENT", "yes") withEnv("NEO4J_db_temporal_timezone", TimeZone.getDefault.getID) withDatabases(Seq("db1", "db2")) } var conf: SparkConf = _ var ss: SparkSession = _ var driver: Driver = _ var tmpDir: File = _ var neo4j: Neo4j = _ @BeforeClass def setUpContainer(): Unit = { if (!server.isRunning) { try { server.start() } catch { case _: Throwable => // } Assume.assumeTrue("Neo4j container is not started", server.isRunning) tmpDir = Files.createTempDirectory("spark-warehouse").toFile tmpDir.deleteOnExit() conf = new SparkConf() .setAppName("neoTest") .setMaster("local[*]") .set("spark.driver.host", "127.0.0.1") .set("spark.sql.warehouse.dir", tmpDir.getAbsolutePath) ss = SparkSession.builder.config(conf).getOrCreate() driver = GraphDatabase.driver(server.getBoltUrl, AuthTokens.none()) neo4j = Neo4jDetector.INSTANCE.detect(driver) } } @AfterClass def tearDownContainer() = { TestUtil.closeSafely(driver) TestUtil.closeSafely(server) TestUtil.closeSafely(ss) } def session(database: String = ""): Session = { if (database.isEmpty) { driver.session() } else { driver.session(SessionConfig.forDatabase(database)) } } } class SparkConnectorScalaSuiteIT {} ================================================ FILE: test-support/src/main/scala/org/neo4j/spark/SparkConnectorScalaSuiteWithApocIT.scala 
================================================
/*
 * Copyright (c) "Neo4j"
 * Neo4j Sweden AB [https://neo4j.com]
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package org.neo4j.spark

import org.apache.spark.SparkConf
import org.apache.spark.sql.SparkSession
import org.junit.AfterClass
import org.junit.Assume
import org.junit.BeforeClass
import org.neo4j.Neo4jContainerExtension
import org.neo4j.caniuse.Neo4j
import org.neo4j.caniuse.Neo4jDetector
import org.neo4j.driver._

import java.util.TimeZone

/**
 * Shared suite fixture with APOC installed: one Neo4j container (auth
 * disabled, APOC plugin, extra databases db1/db2), one local SparkSession
 * and one driver for all member tests. The whole suite is skipped when the
 * server image ships without APOC.
 */
object SparkConnectorScalaSuiteWithApocIT {

  val server: Neo4jContainerExtension = new Neo4jContainerExtension {
    withNeo4jConfig("dbms.security.auth_enabled", "false")
    withEnv("NEO4J_ACCEPT_LICENSE_AGREEMENT", "yes")
    withEnv("NEO4JLABS_PLUGINS", "[\"apoc\"]")
    withEnv("NEO4J_db_temporal_timezone", TimeZone.getDefault.getID)
    withDatabases(Seq("db1", "db2"))
  }

  var conf: SparkConf = _
  var ss: SparkSession = _
  var driver: Driver = _
  var neo4j: Neo4j = _

  @BeforeClass
  def setUpContainer(): Unit = {
    if (!server.isRunning) {
      try {
        server.start()
      } catch {
        case _: Throwable => // deliberately swallowed: the Assume below skips the suite
      }
      Assume.assumeTrue("Neo4j container is not started", server.isRunning)
      conf = new SparkConf()
        .setAppName("neoTest")
        .setMaster("local[*]")
        .set("spark.driver.host", "127.0.0.1")
      ss = SparkSession.builder.config(conf).getOrCreate()
      driver = GraphDatabase.driver(server.getBoltUrl, AuthTokens.none())
      neo4j = Neo4jDetector.INSTANCE.detect(driver)
    }
    // FIX: the session opened for the APOC probe was never closed (leaked a
    // Bolt session per suite run); close it even when the Assume throws.
    val apocCheckSession = session()
    try {
      Assume.assumeTrue("Neo4j Preview versions doesn't have APOC", TestUtil.hasApoc(apocCheckSession))
    } finally {
      TestUtil.closeSafely(apocCheckSession)
    }
  }

  @AfterClass
  def tearDownContainer(): Unit = {
    TestUtil.closeSafely(driver)
    TestUtil.closeSafely(server)
    TestUtil.closeSafely(ss)
  }

  /** Opens a session; an empty name means the driver's default database. */
  def session(database: String = ""): Session = {
    if (database.isEmpty) {
      driver.session()
    } else {
      driver.session(SessionConfig.forDatabase(database))
    }
  }
}

class SparkConnectorScalaSuiteWithApocIT {}
================================================ FILE: test-support/src/main/scala/org/neo4j/spark/SparkConnectorScalaSuiteWithGdsBase.scala ================================================
/*
 * Copyright (c) "Neo4j"
 * Neo4j Sweden AB [https://neo4j.com]
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/ package org.neo4j.spark import org.apache.spark.SparkConf import org.apache.spark.sql.SparkSession import org.junit.AfterClass import org.junit.Assume import org.junit.Before import org.junit.BeforeClass import org.junit.Rule import org.junit.rules.TestName import org.neo4j.Closeables.use import org.neo4j.Neo4jContainerExtension import org.neo4j.caniuse.Neo4j import org.neo4j.caniuse.Neo4jDetector import org.neo4j.driver._ import java.util.TimeZone import scala.annotation.meta.getter object SparkConnectorScalaSuiteWithGdsBase { val server: Neo4jContainerExtension = new Neo4jContainerExtension { withNeo4jConfig("dbms.security.auth_enabled", "false") withEnv("NEO4J_ACCEPT_LICENSE_AGREEMENT", "yes") withEnv("NEO4JLABS_PLUGINS", "[\"graph-data-science\"]") withEnv("NEO4J_db_temporal_timezone", TimeZone.getDefault.getID) withDatabases(Seq("db1", "db2")) } var conf: SparkConf = _ var ss: SparkSession = _ var driver: Driver = _ var neo4j: Neo4j = _ @BeforeClass def setUpContainer(): Unit = { if (!server.isRunning) { try { server.start() } catch { case _: Throwable => // } Assume.assumeTrue("Neo4j container is not started", server.isRunning) conf = new SparkConf() .setAppName("neoTest") .setMaster("local[*]") .set("spark.driver.host", "127.0.0.1") ss = SparkSession.builder.config(conf).getOrCreate() driver = GraphDatabase.driver(server.getBoltUrl, AuthTokens.none()) neo4j = Neo4jDetector.INSTANCE.detect(driver) } } @AfterClass def tearDownContainer(): Unit = { TestUtil.closeSafely(driver) TestUtil.closeSafely(server) TestUtil.closeSafely(ss) } def session(database: String = ""): Session = { if (database.isEmpty) { driver.session() } else { driver.session(SessionConfig.forDatabase(database)) } } } class SparkConnectorScalaSuiteWithGdsBase { val conf: SparkConf = SparkConnectorScalaSuiteWithGdsBase.conf val ss: SparkSession = SparkConnectorScalaSuiteWithGdsBase.ss @(Rule @getter) val testName: TestName = new TestName @Before def before(): Unit = { 
use(SparkConnectorScalaSuiteWithGdsBase.session("system")) { session => session.run("CREATE OR REPLACE DATABASE neo4j WAIT 30 seconds") .consume() } } } ================================================ FILE: test-support/src/main/scala/org/neo4j/spark/TestUtil.scala ================================================ /* * Copyright (c) "Neo4j" * Neo4j Sweden AB [https://neo4j.com] * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ package org.neo4j.spark import org.neo4j.driver.Session import org.neo4j.driver.Transaction import org.slf4j.Logger import org.testcontainers.utility.DockerImageName import java.util.Properties case class Version(major: Int, minor: Int, patch: Int) { override def toString: String = s"${major}.${minor}.${patch}" } object Version { implicit val ordering: Ordering[Version] = Ordering.by(v => (v.major, v.minor, v.patch)) def parse(version: String): Version = { val fields = version.split("\\.") .map(field => { val additionalInfoIndex = field.indexOf('-') if (additionalInfoIndex.equals(-1)) field else field.substring(0, additionalInfoIndex) }) .map(_.toInt) .toList Version(fields.head, fields(1), fields(2)) } } object Versions { val GDS_2_4: Version = Version(2, 4, 0) val GDS_2_5: Version = Version(2, 5, 0) val NEO4J_4_4: Version = Version(4, 4, 0) val NEO4J_5: Version = Version(5, 0, 0) val NEO4J_5_13: Version = Version(5, 13, 0) } object TestUtil { def neo4jImage(): DockerImageName = { val image = Option(System.getenv("NEO4J_TEST_IMAGE")) .map(_.trim) 
.filter(_.nonEmpty) // avoids Java 11-only isBlank .getOrElse(throw new IllegalArgumentException("NEO4J_TEST_IMAGE environment variable is not defined!")) DockerImageName.parse(image).asCompatibleSubstituteFor("neo4j") } def gdsVersion(session: Session): Version = { Version.parse(session.run( "CALL gds.debug.sysInfo() YIELD key, value WHERE key = 'gdsVersion' RETURN value" ).single().get(0).asString()) } def neo4jVersion(session: Session): Version = { Version.parse(session.run( "CALL dbms.components() YIELD name, versions WHERE name = 'Neo4j Kernel' RETURN versions[0]" ).single().get(0).asString()) } def hasApoc(session: Session): Boolean = { val result = session.run( "SHOW PROCEDURES YIELD name RETURN any(x IN collect(name) WHERE x STARTS WITH 'apoc.') AS hasApoc" ) result.single().get("hasApoc").asBoolean() } def closeSafely(autoCloseable: AutoCloseable, logger: Logger = null): Unit = { try { autoCloseable match { case s: Session => if (s.isOpen) s.close() case t: Transaction => if (t.isOpen) t.close() case null => () case _ => autoCloseable.close() } } catch { case t: Throwable => if (logger != null) { t.printStackTrace() logger.warn(s"Cannot close ${autoCloseable.getClass.getSimpleName} because of the following exception:", t) } } } } ================================================ FILE: test-support/src/test/scala/org/neo4j/spark/VersionTest.scala ================================================ /* * Copyright (c) "Neo4j" * Neo4j Sweden AB [https://neo4j.com] * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
* See the License for the specific language governing permissions and * limitations under the License. */ package org.neo4j.spark import org.junit.Assert.assertEquals import org.junit.Test class VersionTest { @Test def parses_versions(): Unit = { assertEquals(Version(5, 26, 399), Version.parse("5.26.399")) assertEquals(Version(2025, 11, 0), Version.parse("2025.11.0-41865")) } }